diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index db9f41a9bf64fec7fe1ffc20d4c69b0fcc410bdb..3b9dcd9840555d81de4d188768bb9d42f973d6ac 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -128,6 +128,7 @@ def render(args=None): "fixed_view": False, "view_size": 60, "slow_view": 0.8, + "cut_frames": 0, } for key, value in default_options.items(): diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 15e11e51005249e133f5144c695498375ac1601b..e5bb2b1e9092d371b20583f5da87770fb3301ba5 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -39,6 +39,8 @@ from matplotlib import animation from matplotlib import patches from matplotlib import cm +from tqdm import tqdm + from subprocess import run # Remember: Update docstring when updating these two global variables @@ -833,6 +835,7 @@ class File(h5py.File): "margin": 15, "slow_view": 0.8, "slow_zoom": 0.95, + "cut_frames": None, } options = { @@ -891,70 +894,106 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border, grid_points] + return lines + entity_polygons + [border] + + n_frames = self.entity_poses.shape[1] + if options["cut_frames"] > 0: + n_frames = min(n_frames, options["cut_frames"]) + + n_frames = int(n_frames / options["speedup"]) + + if video_path is not None: + pbar = tqdm(range(n_frames)) def update(frame): - entity_poses = self.entity_poses_rad + pbar.update(1) + pbar.refresh() - file_frame = (int)(frame * options["speedup"]) % entity_poses.shape[1] - this_pose = entity_poses[:, file_frame] + if frame < n_frames: + entity_poses = self.entity_poses_rad - if not options["fixed_view"]: - self.middle_of_swarm = options["slow_view"] * self.middle_of_swarm + ( - 1 - options["slow_view"] - ) * np.mean(this_pose, axis=0) + file_frame = frame * options["speedup"] + this_pose = entity_poses[:, file_frame] - # Find the maximal distance between the entities in x or y direction - min_view = np.max( - (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] - ) - new_view_size = np.max( - [options["view_size"], min_view + options["margin"]] - ) - self.view_size = ( - options["slow_zoom"] * self.view_size - + (1 - options["slow_zoom"]) * new_view_size - ) + if not options["fixed_view"]: + self.middle_of_swarm = options[ + "slow_view" + ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( + this_pose, axis=0 + ) - ax.set_xlim( - self.middle_of_swarm[0] - self.view_size / 2, - self.middle_of_swarm[0] + self.view_size / 2, - ) - ax.set_ylim( - self.middle_of_swarm[1] - self.view_size / 2, - self.middle_of_swarm[1] + self.view_size / 2, - ) + # Find the maximal distance between the entities in x or y direction + min_view = np.max( + (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] + ) + new_view_size = np.max( + [options["view_size"], min_view + options["margin"]] + ) + self.view_size = ( + options["slow_zoom"] * self.view_size + + (1 - options["slow_zoom"]) * new_view_size + ) - for i_entity in range(n_entities): + ax.set_xlim( + self.middle_of_swarm[0] - self.view_size / 2, + self.middle_of_swarm[0] + self.view_size / 2, + ) + ax.set_ylim( + self.middle_of_swarm[1] - self.view_size / 2, + self.middle_of_swarm[1] + self.view_size / 2, + ) - poses_trail = entity_poses[ - i_entity, max(0, file_frame - options["trail"]) : file_frame + poses_trails = entity_poses[ + :, max(0, file_frame - options["trail"]) : file_frame ] + for i_entity in range(n_entities): + lines[i_entity].set_data( + poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1] + ) - lines[i_entity].set_data(poses_trail[:, 0], poses_trail[:, 1]) - - current_pose = entity_poses[i_entity, file_frame] - t = mpl.transforms.Affine2D().translate( - current_pose[0], current_pose[1] + current_pose = entity_poses[i_entity, file_frame] + t = mpl.transforms.Affine2D().translate( + current_pose[0], current_pose[1] + ) + r = mpl.transforms.Affine2D().rotate(current_pose[2]) + tra = r + t + ax.transData + entity_polygons[i_entity].set_transform(tra) + else: + raise Exception( + f"Frame is bigger than n_frames {file_frame} of {n_frames}" ) - r = mpl.transforms.Affine2D().rotate(current_pose[2]) - tra = r + t + ax.transData - entity_polygons[i_entity].set_transform(tra) - return lines + entity_polygons + [border, grid_points] + return lines + entity_polygons + [border] + + print(f"Preparing to render n_frames: {n_frames}") ani = animation.FuncAnimation( fig, update, - frames=self.entity_positions.shape[1], + frames=n_frames, init_func=init, - blit=True, - interval=1.0 / self.frequency, - save_count=500, + blit=False, + interval=self.frequency, + repeat=False, ) if video_path is not None: - print(f"saving video to {video_path}") - writervideo = animation.FFMpegWriter(fps=25) - ani.save(video_path, writer=writervideo) + + # if i % (n / 40) == 0: + # print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)") + + video_path = Path(video_path) + if video_path.exists(): + y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)") + if y == "y": + video_path.unlink() + + if not video_path.exists(): + print(f"saving video to {video_path}") + + writervideo = animation.FFMpegWriter(fps=25) + ani.save( + video_path, + writer=writervideo, + ) else: plt.show()