diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 8d588c1503ad62a3e8d5cce3d17121435c833889..c82c2ccbc1d547317c7fece8a2b9ba98c416490f 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,6 +1,6 @@ repos: - repo: https://github.com/psf/black - rev: 21.6b0 + rev: 22.3.0 hooks: - id: black - repo: https://github.com/pre-commit/pre-commit-hooks diff --git a/setup.py b/setup.py index a9ebe9c5d415412a41abe7e72af70e682d3e9601..14d44d9b7e322e78d8823e2266e77b06d8e2c12e 100644 --- a/setup.py +++ b/setup.py @@ -59,6 +59,7 @@ setup( "deprecation", "tqdm", "pre-commit", + "torch", ], classifiers=[ "Development Status :: 3 - Alpha", @@ -68,6 +69,7 @@ setup( "Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", ], python_requires=">=3.6", packages=[f"robofish.{p}" for p in find_packages("src/robofish")], diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 234d5029c4d0429bb290c205c355b88ba36b7f13..e9538fd3e146d18251362422677c43b26748a2c9 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -31,6 +31,7 @@ def function_dict(): "follow_iid": base.evaluate_follow_iid, "individual_speed": base.evaluate_individual_speed, "individual_iid": base.evaluate_individual_iid, + "quiver": base.evaluate_quiver, "all": base.evaluate_all, } diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index fedefc76fdef3735a3af266e31022a027575e587..62ea41932d642b5ed58d16ce7f70edbc7c4d92b7 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -22,6 +22,7 @@ from scipy import stats from tqdm import tqdm import inspect import random +import warnings def evaluate_speed( @@ -420,6 +421,152 @@ def evaluate_tank_position( return fig +def evaluate_quiver( + paths: Iterable[Union[str, Path]], + labels: Iterable[str] = None, + predicate=None, + poses_from_paths=None, + speeds_turns_from_paths=None, + file_settings=None, + max_files=None, + bins=25, +): + """BETA: Plot the flow of movement in the files.""" + try: + import torch + from fish_models.models.pascals_lstms.attribution import SocialVectors + except ImportError: + print( + "Either torch or fish_models is not installed.", + "The social vector should come from robofish io.", + "This is a known issue (#24)", + ) + return + + if poses_from_paths is None: + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) + if speeds_turns_from_paths is None: + speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( + paths, "speeds_turns", max_files=max_files + ) + # print(poses_from_paths.shape) + if len(poses_from_paths) != 1: + warnings.warn( + "NotImplemented: Only one path is supported for now in quiver plot." + ) + poses_from_paths = poses_from_paths[:1] + + poses_from_paths = np.array(poses_from_paths) + speeds_turns_from_paths = np.array(speeds_turns_from_paths) + + all_poses = torch.tensor(poses_from_paths)[0, :, :, :-1].reshape((-1, 4)) + speed = torch.tensor(speeds_turns_from_paths)[0, ..., 0].flatten() + all_poses_speed = torch.clone(all_poses) + all_poses_speed[:, 2] *= speed + all_poses_speed[:, 3] *= speed + + tank_b = torch.linspace( + torch.min(all_poses[:, :2]), torch.max(all_poses[:, :2]), bins + ) + nb = len(tank_b) + poses_buckets = torch.bucketize(all_poses[:, :2], tank_b) + + tank_directions = np.zeros((nb, nb, 2)) + tank_directions_speed = np.zeros_like(tank_directions) + tank_count = np.zeros((nb, nb)) + + for x in tqdm(range(nb)): + for y in range(nb): + d = torch.where( + torch.logical_and(poses_buckets[:, 0] == x, poses_buckets[:, 1] == y) + )[0] + if len(d) > 0: + # print("d", d) + # d_v = torch.stack((torch.cos(all_poses[d,2]), torch.sin(all_poses[d,2])), dim=1) + # print("dv",d_v.shape) + # tank_directions[x, y] = torch.mean(all_poses[d, 2:], dim=0) + tank_directions_speed[x, y] = torch.mean(all_poses_speed[d, 2:], dim=0) + # print(tank_directions[x,y] - tank_directions_speed[x,y]) + tank_count[x, y] = len(d) + + sv = SocialVectors(torch.tensor(poses_from_paths[0])) + sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape( + (-1, 3) + ) # [:1000] + sv_r = torch.cat( + (sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1 + ) + sv_r_s = torch.clone(sv_r) + sv_r_s[:, 2] *= speed + sv_r_s[:, 3] *= speed + + social_b = torch.linspace(-20, 20, bins) + nb = len(social_b) + poses_buckets = torch.bucketize(sv_r[:, :2], social_b) + + social_directions = np.zeros((nb, nb, 2)) + social_directions_speed = np.zeros_like(social_directions) + social_count = np.zeros((nb, nb)) + + for x in tqdm(range(nb)): + for y in range(nb): + d = torch.where( + torch.logical_and(poses_buckets[:, 0] == x, poses_buckets[:, 1] == y) + )[0] + if len(d) > 0: + # print("d", d) + + # print("dv",d_v.shape) + social_directions[x, y] = torch.mean(sv_r[d, 2:], dim=0) + social_directions_speed[x, y] = torch.mean(sv_r_s[d, 2:], dim=0) + social_count[x, y] = len(d) + + tank_xs, tank_ys = np.meshgrid(tank_b, tank_b) + social_xs, social_ys = np.meshgrid(social_b, social_b) + fig, axs = plt.subplots(1, 2, figsize=(20, 10)) + + axs[0].quiver( + tank_xs, + tank_ys, + tank_directions_speed[..., 0], + -tank_directions_speed[..., 1], + alpha=tank_count / (np.max(tank_count)), + ) + axs[0].set_title("Avg Direction weighted by speed") + axs[1].quiver( + social_xs, + social_ys, + social_directions_speed[..., 0], + -social_directions_speed[..., 1], + alpha=0.3 + social_count * 0.7 / (np.max(social_count)), + ) + axs[1].set_title("Avg Direction weighted by speed in social situations") + + fish_size = 50 / 20 + linewidth = 2 + + head = axs[1].add_patch( + plt.Circle( + xy=(0, 0), + radius=fish_size, + linewidth=linewidth, + # color="white", + edgecolor="black", + facecolor="white", + fill=True, + alpha=0.8, + ) + ) + + (tail,) = axs[1].plot( + [-fish_size, -fish_size * 2], + [0, 0], + color="black", + linewidth=linewidth, + ) + return fig + + def evaluate_social_vector( paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index c6a93852d63d07d4bc7b598075e280759dbd4992..fef92106062f3fc0fc85ed6c8db73163813348ec 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -16,6 +16,7 @@ from robofish.io import utils import argparse import logging import warnings +from tqdm.auto import tqdm def print_file(args=None): @@ -75,11 +76,15 @@ def update_calculated_data(args=None): assert len(files) > 0, f"No files found in path {args.path}." - for fp in files: - print(f"File {fp}") + pbar = tqdm(files) + for fp in pbar: + try: with robofish.io.File(fp, "r+", validate_poses_hash=False) as f: - f.update_calculated_data(verbose=True) + if f.update_calculated_data(verbose=False): + pbar.set_description(f"File {fp} was updated") + else: + pbar.set_description(f"File {fp} was already up to date") except Exception as e: warnings.warn(f"The file {fp} could not be updated.") print(e) @@ -203,7 +208,8 @@ def render(args=None): "fixed_view": False, "view_size": 60, "slow_view": 0.8, - "cut_frames": 0, + "cut_frames_start": 0, + "cut_frames_end": 0, "show_text": False, "render_goals": False, "render_targets": False, diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index 3375f377909410f5b6044e2bfdf444d078a09fc7..9fe37a30243af786f182301ea567e65aa23ff4cb 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -248,6 +248,7 @@ class Entity(h5py.Group): return np.stack([speed, turn], axis=-1) def update_calculated_data(self, verbose=False, force_update=False): + changed = False if ( "poses_hash" not in self.attrs or self.attrs["poses_hash"] != self.poses_hash @@ -258,8 +259,8 @@ class Entity(h5py.Group): ): try: self.attrs["poses_hash"] = self.poses_hash - self.attrs["unfinished_calculations"] = True if "orientations" in self: + self.attrs["unfinished_calculations"] = True ori_rad = self.calculate_orientations_rad() if "calculated_orientations_rad" in self: del self["calculated_orientations_rad"] @@ -274,6 +275,7 @@ class Entity(h5py.Group): del self.attrs["unfinished_calculations"] if verbose: + changed = True print( f"Updated calculated data for entity {self.name} with poses_hash {self.poses_hash}" ) @@ -291,6 +293,7 @@ class Entity(h5py.Group): ) assert self.attrs["poses_hash"] == self.poses_hash + return changed def calculate_orientations_rad(self): ori_rad = utils.limit_angle_range( diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index b29a54659b48e49fa3314d3c789ffcba5e330459..85adc8508d602b7771ad7a6b7ff61eca54942af1 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -40,7 +40,7 @@ from matplotlib import animation from matplotlib import patches from matplotlib import cm -from tqdm import tqdm +from tqdm.auto import tqdm from subprocess import run @@ -516,8 +516,8 @@ class File(h5py.File): return entity_names def update_calculated_data(self, verbose=False): - for e in self.entities: - e.update_calculated_data(verbose) + changed = any([e.update_calculated_data(verbose) for e in self.entities]) + return changed def clear_calculated_data(self, verbose=True): """Delete all calculated data from the files.""" @@ -638,7 +638,7 @@ class File(h5py.File): def select_entity_property( self, predicate: types.LambdaType = None, - entity_property: property = Entity.poses, + entity_property: Union[property, str] = Entity.poses, ) -> Iterable: """Get a property of selected entities. @@ -648,7 +648,7 @@ class File(h5py.File): Args: predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") - entity_property: a property of the Entity class (example: Entity.poses_rad) + entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset. Returns: An three dimensional array of all properties of all entities with the shape (entity, time, property_length). If an entity has a shorter length of the property, the output will be filled with nans. @@ -661,7 +661,10 @@ class File(h5py.File): assert self.common_sampling(entities) is not None # Initialize poses output array - properties = [entity_property.__get__(entity) for entity in entities] + if isinstance(entity_property, str): + properties = [entity[entity_property] for entity in entities] + else: + properties = [entity_property.__get__(entity) for entity in entities] max_timesteps = max([0] + [p.shape[0] for p in properties]) @@ -829,6 +832,8 @@ class File(h5py.File): self, ax=None, lw_distances=False, + lw=2, + ms=32, figsize=None, step_size=4, c=None, @@ -876,7 +881,6 @@ class File(h5py.File): ) else: step_size = poses.shape[1] - line_width = 1 cmap = cm.get_cmap(cmap) @@ -907,8 +911,7 @@ class File(h5py.File): for t in range(skip_timesteps, timesteps, step_size): if lw_distances: lw = np.mean(line_width[t : t + step_size + 1]) - else: - lw = 1 + ax.plot( poses[fish_id, t : t + step_size + 1, 0], poses[fish_id, t : t + step_size + 1, 1], @@ -917,26 +920,29 @@ class File(h5py.File): ) # Plotting outside of the figure to have the label ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) - ax.scatter( - [poses[:, skip_timesteps, 0]], - [poses[:, skip_timesteps, 1]], - marker="h", - c="black", - s=32, - label="Start", - zorder=5, - ) + + # ax.scatter( + # [poses[:, skip_timesteps, 0]], + # [poses[:, skip_timesteps, 1]], + # marker="h", + # c="black", + # s=ms, + # label="Start", + # zorder=5, + # ) ax.scatter( [poses[:, max_timesteps, 0]], [poses[:, max_timesteps, 1]], marker="x", c="black", - s=32, + s=ms, label="End", zorder=5, ) - if legend: - ax.legend(loc="lower right") + if legend and isinstance(legend, str): + ax.legend(legend) + elif legend: + ax.legend() ax.set_xlabel("x [cm]") ax.set_ylabel("y [cm]") @@ -984,10 +990,12 @@ class File(h5py.File): "margin": 15, "slow_view": 0.8, "slow_zoom": 0.95, - "cut_frames": None, + "cut_frames_start": None, + "cut_frames_end": None, "show_text": False, "render_goals": False, "render_targets": False, + "dpi": 200, } options = { @@ -997,7 +1005,7 @@ class File(h5py.File): fig, ax = plt.subplots(figsize=(10, 10)) ax.set_facecolor("gray") - + plt.tight_layout(pad=0.05) n_entities = len(self.entities) lines = [ plt.plot([], [], lw=options["linewidth"], zorder=0)[0] @@ -1036,12 +1044,6 @@ class File(h5py.File): # border = plt.plot(border_vertices[0], border_vertices[1], "k") border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) - start_pose = self.entity_poses_rad[:, 0] - - self.middle_of_swarm = np.mean(start_pose, axis=0) - min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) - self.view_size = np.max([options["view_size"], min_view + options["margin"]]) - def title(file_frame: int) -> str: """Search for datasets containing text for displaying it in the video""" output = [] @@ -1087,10 +1089,23 @@ class File(h5py.File): return lines + entity_polygons + [border] + points n_frames = self.entity_poses.shape[1] - if options["cut_frames"] > 0: - n_frames = min(n_frames, options["cut_frames"]) - n_frames = int(n_frames / options["speedup"]) + if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None: + options["cut_frames_end"] = n_frames + if options["cut_frames_start"] is None: + options["cut_frames_start"] = 0 + frame_range = ( + options["cut_frames_start"], + min(n_frames, options["cut_frames_end"]), + ) + + n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"]) + + start_pose = self.entity_poses_rad[:, frame_range[0]] + + self.middle_of_swarm = np.mean(start_pose, axis=0) + min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) + self.view_size = np.max([options["view_size"], min_view + options["margin"]]) if video_path is not None: pbar = tqdm(range(n_frames)) @@ -1103,7 +1118,7 @@ class File(h5py.File): if frame < n_frames: entity_poses = self.entity_poses_rad - file_frame = frame * options["speedup"] + file_frame = (frame * options["speedup"]) + frame_range[0] this_pose = entity_poses[:, file_frame] if not options["fixed_view"]: @@ -1195,10 +1210,8 @@ class File(h5py.File): if not video_path.exists(): print(f"saving video to {video_path}") - writervideo = animation.FFMpegWriter(fps=25) - ani.save( - video_path, - writer=writervideo, - ) + writervideo = animation.FFMpegWriter(fps=self.frequency) + ani.save(video_path, writer=writervideo, dpi=options["dpi"]) + plt.close() else: plt.show()