Skip to content
Snippets Groups Projects
Commit 1da164f4 authored by Andi Gerken's avatar Andi Gerken
Browse files

Fixed robofish-io-render

parent fa63c8cf
No related branches found
No related tags found
No related merge requests found
...@@ -128,6 +128,7 @@ def render(args=None): ...@@ -128,6 +128,7 @@ def render(args=None):
"fixed_view": False, "fixed_view": False,
"view_size": 60, "view_size": 60,
"slow_view": 0.8, "slow_view": 0.8,
"cut_frames": 0,
} }
for key, value in default_options.items(): for key, value in default_options.items():
......
...@@ -39,6 +39,8 @@ from matplotlib import animation ...@@ -39,6 +39,8 @@ from matplotlib import animation
from matplotlib import patches from matplotlib import patches
from matplotlib import cm from matplotlib import cm
from tqdm import tqdm
from subprocess import run from subprocess import run
# Remember: Update docstring when updating these two global variables # Remember: Update docstring when updating these two global variables
...@@ -833,6 +835,7 @@ class File(h5py.File): ...@@ -833,6 +835,7 @@ class File(h5py.File):
"margin": 15, "margin": 15,
"slow_view": 0.8, "slow_view": 0.8,
"slow_zoom": 0.95, "slow_zoom": 0.95,
"cut_frames": None,
} }
options = { options = {
...@@ -891,18 +894,33 @@ class File(h5py.File): ...@@ -891,18 +894,33 @@ class File(h5py.File):
for e_poly in entity_polygons: for e_poly in entity_polygons:
ax.add_patch(e_poly) ax.add_patch(e_poly)
ax.add_patch(border) ax.add_patch(border)
return lines + entity_polygons + [border, grid_points] return lines + entity_polygons + [border]
n_frames = self.entity_poses.shape[1]
if options["cut_frames"] > 0:
n_frames = min(n_frames, options["cut_frames"])
n_frames = int(n_frames / options["speedup"])
if video_path is not None:
pbar = tqdm(range(n_frames))
def update(frame): def update(frame):
pbar.update(1)
pbar.refresh()
if frame < n_frames:
entity_poses = self.entity_poses_rad entity_poses = self.entity_poses_rad
file_frame = (int)(frame * options["speedup"]) % entity_poses.shape[1] file_frame = frame * options["speedup"]
this_pose = entity_poses[:, file_frame] this_pose = entity_poses[:, file_frame]
if not options["fixed_view"]: if not options["fixed_view"]:
self.middle_of_swarm = options["slow_view"] * self.middle_of_swarm + ( self.middle_of_swarm = options[
1 - options["slow_view"] "slow_view"
) * np.mean(this_pose, axis=0) ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
this_pose, axis=0
)
# Find the maximal distance between the entities in x or y direction # Find the maximal distance between the entities in x or y direction
min_view = np.max( min_view = np.max(
...@@ -925,13 +943,13 @@ class File(h5py.File): ...@@ -925,13 +943,13 @@ class File(h5py.File):
self.middle_of_swarm[1] + self.view_size / 2, self.middle_of_swarm[1] + self.view_size / 2,
) )
for i_entity in range(n_entities): poses_trails = entity_poses[
:, max(0, file_frame - options["trail"]) : file_frame
poses_trail = entity_poses[
i_entity, max(0, file_frame - options["trail"]) : file_frame
] ]
for i_entity in range(n_entities):
lines[i_entity].set_data(poses_trail[:, 0], poses_trail[:, 1]) lines[i_entity].set_data(
poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
)
current_pose = entity_poses[i_entity, file_frame] current_pose = entity_poses[i_entity, file_frame]
t = mpl.transforms.Affine2D().translate( t = mpl.transforms.Affine2D().translate(
...@@ -940,21 +958,42 @@ class File(h5py.File): ...@@ -940,21 +958,42 @@ class File(h5py.File):
r = mpl.transforms.Affine2D().rotate(current_pose[2]) r = mpl.transforms.Affine2D().rotate(current_pose[2])
tra = r + t + ax.transData tra = r + t + ax.transData
entity_polygons[i_entity].set_transform(tra) entity_polygons[i_entity].set_transform(tra)
return lines + entity_polygons + [border, grid_points] else:
raise Exception(
f"Frame is bigger than n_frames {file_frame} of {n_frames}"
)
return lines + entity_polygons + [border]
print(f"Preparing to render n_frames: {n_frames}")
ani = animation.FuncAnimation( ani = animation.FuncAnimation(
fig, fig,
update, update,
frames=self.entity_positions.shape[1], frames=n_frames,
init_func=init, init_func=init,
blit=True, blit=False,
interval=1.0 / self.frequency, interval=self.frequency,
save_count=500, repeat=False,
) )
if video_path is not None: if video_path is not None:
# if i % (n / 40) == 0:
# print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)")
video_path = Path(video_path)
if video_path.exists():
y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)")
if y == "y":
video_path.unlink()
if not video_path.exists():
print(f"saving video to {video_path}") print(f"saving video to {video_path}")
writervideo = animation.FFMpegWriter(fps=25) writervideo = animation.FFMpegWriter(fps=25)
ani.save(video_path, writer=writervideo) ani.save(
video_path,
writer=writervideo,
)
else: else:
plt.show() plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment