diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 8f3e1d260f8e3be4a36c91d8781121536cf25487..85adc8508d602b7771ad7a6b7ff61eca54942af1 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -40,7 +40,7 @@ from matplotlib import animation from matplotlib import patches from matplotlib import cm -from tqdm import tqdm +from tqdm.auto import tqdm from subprocess import run @@ -638,7 +638,7 @@ class File(h5py.File): def select_entity_property( self, predicate: types.LambdaType = None, - entity_property: property = Entity.poses, + entity_property: Union[property, str] = Entity.poses, ) -> Iterable: """Get a property of selected entities. @@ -648,7 +648,7 @@ class File(h5py.File): Args: predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") - entity_property: a property of the Entity class (example: Entity.poses_rad) + entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset. Returns: An three dimensional array of all properties of all entities with the shape (entity, time, property_length). If an entity has a shorter length of the property, the output will be filled with nans. @@ -661,7 +661,10 @@ class File(h5py.File): assert self.common_sampling(entities) is not None # Initialize poses output array - properties = [entity_property.__get__(entity) for entity in entities] + if isinstance(entity_property, str): + properties = [entity[entity_property] for entity in entities] + else: + properties = [entity_property.__get__(entity) for entity in entities] max_timesteps = max([0] + [p.shape[0] for p in properties]) @@ -992,6 +995,7 @@ class File(h5py.File): "show_text": False, "render_goals": False, "render_targets": False, + "dpi": 200, } options = { @@ -1001,7 +1005,7 @@ class File(h5py.File): fig, ax = plt.subplots(figsize=(10, 10)) ax.set_facecolor("gray") - + plt.tight_layout(pad=0.05) n_entities = len(self.entities) lines = [ plt.plot([], [], lw=options["linewidth"], zorder=0)[0] @@ -1040,12 +1044,6 @@ class File(h5py.File): # border = plt.plot(border_vertices[0], border_vertices[1], "k") border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) - start_pose = self.entity_poses_rad[:, 0] - - self.middle_of_swarm = np.mean(start_pose, axis=0) - min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) - self.view_size = np.max([options["view_size"], min_view + options["margin"]]) - def title(file_frame: int) -> str: """Search for datasets containing text for displaying it in the video""" output = [] @@ -1094,6 +1092,8 @@ class File(h5py.File): if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None: options["cut_frames_end"] = n_frames + if options["cut_frames_start"] is None: + options["cut_frames_start"] = 0 frame_range = ( options["cut_frames_start"], min(n_frames, options["cut_frames_end"]), @@ -1101,6 +1101,12 @@ class File(h5py.File): n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"]) + start_pose = self.entity_poses_rad[:, frame_range[0]] + + self.middle_of_swarm = np.mean(start_pose, axis=0) + min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) + self.view_size = np.max([options["view_size"], min_view + options["margin"]]) + if video_path is not None: pbar = tqdm(range(n_frames)) @@ -1204,10 +1210,8 @@ class File(h5py.File): if not video_path.exists(): print(f"saving video to {video_path}") - writervideo = animation.FFMpegWriter(fps=25) - ani.save( - video_path, - writer=writervideo, - ) + writervideo = animation.FFMpegWriter(fps=self.frequency) + ani.save(video_path, writer=writervideo, dpi=options["dpi"]) + plt.close() else: plt.show()