diff --git a/.gitignore b/.gitignore
index 0cd5109888fa7ddcebdb5a6386249d8b1aed7110..674857fcdf8c6b545bd3339ba3b7f81210cfd9ae 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,8 @@ env
 *.hdf5
 *.mp4
 
+# Ignore all files with ignore_ prefix
+ignore_*
 !tests/resources/*.hdf5
 
 feature_requests.md
diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index 5c837bdb9f8490ea2fde70f4ea892cf18d541f4c..78697fd24c4374f8481c6408eaa1201955d6804c 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -94,26 +94,27 @@ def evaluate(args: dict = None) -> None:
         default=None,
     )
     parser.add_argument(
-        "--save_path",
+        "--save-path",
+        "--save_path",  # for backwards compatibility
         type=str,
         help="Filename for saving resulting graphics.",
         default=None,
     )
     parser.add_argument(
-        "--max_files",
+        "--max-files",
+        "--max_files",  # for backwards compatibility
        type=int,
         default=None,
         help="The maximum number of files to be loaded from the given paths.",
     )
     parser.add_argument(
-        "--add_train_data",
+        "--add-train-data",
+        "--add_train_data",  # for backwards compatibility
         action="store_true",
         help="Add the training data to the evaluation.",
         default=False,
     )
 
-    # TODO: ignore fish/ consider_names
-
     if args is None:
         args = parser.parse_args()
 
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 1f4d2af019ae1df6a08f3541cec7233c8d95afbd..793eb31f9a43881de231c5579c3eec493aa6d0a9 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -31,6 +31,7 @@ def evaluate_speed(
     labels: Iterable[str] = None,
     predicate: Callable[[robofish.io.Entity], bool] = None,
     speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the speed of the entities as histogram.
 
@@ -47,7 +48,7 @@
 
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, _ = utils.get_all_data_from_paths(
-            paths, "speeds_turns", predicate=predicate
+            paths, "speeds_turns", predicate=predicate, max_files=max_files
         )
 
     speeds = []
@@ -92,6 +93,7 @@ def evaluate_turn(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the turn angles of the entities as histogram.
 
@@ -109,7 +111,7 @@
 
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
-            paths, "speeds_turns", predicate=predicate
+            paths, "speeds_turns", predicate=predicate, max_files=max_files
        )
 
     turns = []
@@ -156,6 +158,7 @@ def evaluate_orientation(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the orientations of the entities on a 2d grid.
 
@@ -172,7 +175,7 @@
     """
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate=predicate
+            paths, predicate=predicate, max_files=max_files
         )
 
     world_bounds = [
@@ -243,6 +246,7 @@ def evaluate_relative_orientation(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the relative orientations of the entities as a histogram.
@@ -260,7 +264,7 @@ def evaluate_relative_orientation(
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     orientations = []
@@ -299,6 +303,7 @@ def evaluate_distance_to_wall(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the distances of the entities to the walls as a histogram.
     Lambda function example: lambda e: e.category == "fish"
@@ -315,7 +320,7 @@
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     world_bounds = [
@@ -387,6 +392,7 @@ def evaluate_tank_position(
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
     max_points: int = 4000,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the positions of the entities as a heatmap.
     Lambda function example: lambda e: e.category == "fish"
@@ -404,7 +410,7 @@
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     xy_positions = []
@@ -479,18 +485,13 @@ def evaluate_quiver(
     """
     try:
         import torch
-        from fish_models.models.pascals_lstms.attribution import SocialVectors
     except ImportError:
-        print(
-            "Either torch or fish_models is not installed.",
-            "The social vector should come from robofish io.",
-            "This is a known issue (#24)",
-        )
+        print("Torch is not installed and could not be imported")
         return
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate, predicate
+            paths, predicate, max_files=max_files
         )
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
@@ -547,10 +548,11 @@ def evaluate_quiver(
             # print(tank_directions[x,y] - tank_directions_speed[x,y])
             tank_count[x, y] = len(d)
 
-    sv = SocialVectors(poses_from_paths[0])
-    sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape(
-        (-1, 3)
-    )  # [:1000]
+    sv_r = torch.tensor(
+        robofish.io.utils.social_vectors_without_focal_zeros(
+            poses_from_paths[0][:, :, :-1]
+        ).reshape((-1, 4))
+    )
     sv_r = torch.cat(
         (sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1
     )
@@ -631,6 +633,7 @@ def evaluate_social_vector(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap.
     Lambda function example: lambda e: e.category == "fish"
@@ -645,17 +648,9 @@ def evaluate_social_vector(
         matplotlib.figure.Figure: The figure of the social vectors.
 
""" - try: - from fish_models.models.pascals_lstms.attribution import SocialVectors - except ImportError as e: - warnings.warn( - "Please install the fish_models package to use this function.\n", e - ) - return - if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) fig, ax = plt.subplots( @@ -675,8 +670,8 @@ def evaluate_social_vector( axis=-1, ) - social_vec = SocialVectors(poses).social_vectors_without_focal_zeros - flat_sv = social_vec.reshape((-1, 3)) + social_vec = robofish.io.utils.social_vectors_without_focal_zeros(poses) + flat_sv = social_vec.reshape((-1, 4)) bins = ( 30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100 @@ -700,6 +695,7 @@ def evaluate_follow_iid( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the follow metric in respect to the inter individual distance (iid). Lambda function example: lambda e: e.category == "fish" @@ -716,7 +712,7 @@ def evaluate_follow_iid( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) follow, iid = [], [] @@ -792,12 +788,11 @@ def evaluate_follow_iid( # Created an issue fig = plt.figure(figsize=(12, 5)) - gs = gridspec.GridSpec(1, len(grids)) - - for i in range(len(grids)): - SeabornFig2Grid(grids[i], fig, gs[i]) - + + for i in range(len(grids)): + SeabornFig2Grid(grids[i], fig, gs[i]) + return fig @@ -808,6 +803,7 @@ def evaluate_tracks_distance( poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, max_timesteps: int = 4000, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -828,6 +824,7 @@ def evaluate_tracks_distance( predicate, lw_distances=True, max_timesteps=max_timesteps, + max_files=max_files, ) @@ -841,7 +838,7 @@ def evaluate_tracks( seed: int = 42, max_timesteps: int = None, verbose: bool = False, - max_files: int = None + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -867,7 +864,7 @@ def evaluate_tracks( max_files_per_path = max([len(files) for files in files_per_path]) min_files_per_path = min([len(files) for files in files_per_path]) - + rows, cols = len(files_per_path), min(6, max_files_per_path) multirow = False @@ -1317,26 +1314,26 @@ def calculate_distLinePoint( def show_values( pc: matplotlib.collections.PolyCollection, fmt: str = "%.2f", **kw: dict ) -> None: - """Show numbers on plt.ax.pccolormesh plot. - - https://stackoverflow.com/questions/25071968/ - heatmap-with-text-in-each-cell-with-matplotlibs-pyplot + """Show numbers on plt.ax.pcolormesh plot. Args: pc(matplotlib.collections.PolyCollection): The plot to show the values on. fmt(str): The format of the values. kw(dict): The keyword arguments. 
""" + values = pc.get_array() pc.update_scalarmappable() ax = pc.axes - for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()): + + for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values.ravel()): x, y = p.vertices[:-2, :].mean(0) - if np.all(color[:3] > 0.5): - color = (0.0, 0.0, 0.0) - else: - color = (1.0, 1.0, 1.0) - if np.all(value == "--"): - ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw) + + # Choose text color based on background + text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0) + + # Only show non-masked values + if not np.ma.is_masked(value): + ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw) class SeabornFig2Grid: diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 4aec636e080ab18bea68faf12b459a118c38315f..f0f03a32a9def768814158e0fb050ff1de586119 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -41,14 +41,16 @@ def print_file(args: argparse.Namespace = None) -> bool: parser.add_argument("path", type=str, help="The path to a hdf5 file") parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, choices=["shape", "full"], default="shape", help="Choose how datasets are printed, either the shapes or the full content is printed", ) parser.add_argument( - "--full_attrs", + "--full-attrs", + "--full_attrs", # backwards compatibility default=False, action="store_true", help="Show full unabbreviated values for attributes", @@ -122,7 +124,8 @@ def validate(args: argparse.Namespace = None) -> int: description="The function can be directly accessed from the commandline and can be given any number of files or folders. The function returns the validity of the files in a human readable format or as a raw output." ) parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, default="h", choices=["h", "raw"], @@ -179,7 +182,9 @@ def render_file(kwargs: Dict) -> None: kwargs (Dict, optional): A dictionary containing the arguments for the render function. """ with robofish.io.File(path=kwargs["path"]) as f: - f.render(**{k:v for k,v in kwargs.items() if k not in ("path", "reference_track")}) + f.render( + **{k: v for k, v in kwargs.items() if k not in ("path", "reference_track")} + ) def overwrite_user_configs() -> None: @@ -248,7 +253,7 @@ def render(args: argparse.Namespace = None) -> None: default=[], help="Custom colors to use for guppies. Use spaces as delimiter. " "To set all guppies to the same color, pass only one color. 
" - "Hexadecimal values, color names and matplotlib abbreviations are supported (\"#000000\", black, k)" + 'Hexadecimal values, color names and matplotlib abbreviations are supported ("#000000", black, k)', ) default_options = { @@ -267,7 +272,7 @@ def render(args: argparse.Namespace = None) -> None: "render_swarm_center": False, "highlight_switches": False, "figsize": 10, - "fov_smoothing_factor": 0.8 + "fov_smoothing_factor": 0.8, } for key, value in default_options.items(): diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 8af3217ae6bf7f3f5c4230fe7319f8bf6fdb95dd..0afc5cdef44b33f972c6bc59e86b4c4f89c2f708 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -936,7 +936,7 @@ class File(h5py.File): else: step_size = poses.shape[1] - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) x_world, y_world = self.world_size if figsize is None: @@ -1026,7 +1026,7 @@ class File(h5py.File): custom_colors: bool = None, dpi: int = 200, figsize: int = 10, - fov_smoothing_factor: float = 0.8 + fov_smoothing_factor: float = 0.8, ) -> None: """Render a video of the file. @@ -1102,26 +1102,50 @@ class File(h5py.File): if render_zones: fovs = [] for ei, e in enumerate(self.entities): - zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", "")) + zone_sizes_str = get_zone_sizes_from_model_str( + self.attrs.get("guppy_model_rollout", "") + ) zone_sizes_attrs = get_zone_sizes_from_attrs(e) - + # Check that there are no zone sizes in the model and in the attributes - assert zone_sizes_str == {} or zone_sizes_attrs == {}, "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." - zone_sizes = zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str - - fov = zone_sizes.get("fov", np.pi*2) + assert ( + zone_sizes_str == {} or zone_sizes_attrs == {} + ), "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." 
+                zone_sizes = (
+                    zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str
+                )
+
+                fov = zone_sizes.get("fov", np.pi * 2)
                 fov = np.rad2deg(fov)
                 zone_sizes.pop("fov", None)
                 fovs.append(fov)
-
+
                 entity_zones = []
                 for zone_size in zone_sizes.values():
                     if fov >= 360:
-                        entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.2, fill=False))
+                        entity_zones.append(
+                            matplotlib.patches.Circle(
+                                (0, 0),
+                                zone_size,
+                                color=fish_colors[ei],
+                                alpha=0.2,
+                                fill=False,
+                            )
+                        )
                     else:
-                        entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.2, fill=False))
+                        entity_zones.append(
+                            matplotlib.patches.Wedge(
+                                (0, 0),
+                                zone_size,
+                                theta1=-fov / 2,
+                                theta2=fov / 2,
+                                color=fish_colors[ei],
+                                alpha=0.2,
+                                fill=False,
+                            )
+                        )
                 zones.append(entity_zones)
-
+
             zones_flat = []
             for zones_fish in zones:
                 for zone in zones_fish:
@@ -1321,7 +1345,7 @@ class File(h5py.File):
         pbar = tqdm(range(n_frames)) if video_path is not None else None
 
         self.fov_orientations = [None] * n_entities
-
+
         def update(frame):
             output_list = []
 
@@ -1399,19 +1423,27 @@ class File(h5py.File):
             poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame]
             for i_entity in range(n_entities):
                 new_ori = this_pose[i_entity, 2]
-                old_ori = self.fov_orientations[i_entity] if self.fov_orientations[i_entity] is not None else new_ori
+                old_ori = (
+                    self.fov_orientations[i_entity]
+                    if self.fov_orientations[i_entity] is not None
+                    else new_ori
+                )
                 # Mix in complex space to handle the wrap around
-                smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (1-fov_smoothing_factor) * np.exp(1j * new_ori)
+                smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (
+                    1 - fov_smoothing_factor
+                ) * np.exp(1j * new_ori)
                 self.fov_orientations[i_entity] = np.angle(smoothed_ori)
-
+
                 ori_deg = np.rad2deg(self.fov_orientations[i_entity])
-
-                for zone in zones[i_entity]:
-                    zone.set_center((
-                        this_pose[i_entity, 0],
-                        this_pose[i_entity, 1],
-                    ))
+
+                for zone in zones[i_entity]:
+                    zone.set_center(
+                        (
+                            this_pose[i_entity, 0],
+                            this_pose[i_entity, 1],
+                        )
+                    )
                     if fovs[i_entity] < 360:
                         zone.theta1 = ori_deg - fovs[i_entity] / 2
                         zone.theta2 = ori_deg + fovs[i_entity] / 2
@@ -1482,18 +1514,11 @@ def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]:
     possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"]
     model_params = e["model"]["parameters"].attrs
     return {p: model_params[p] for p in possible_params if p in model_params}
-
+
 
 def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]:
     zone_sizes = {}
-    for zone in [
-        "zor",
-        "zoa",
-        "zoo",
-        "preferred_d",
-        "neighbor_radius",
-        "fov"
-    ]:
+    for zone in ["zor", "zoa", "zoo", "preferred_d", "neighbor_radius", "fov"]:
         match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model)
         if match:
             value = float(match.group(1))
diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py
index 940879b0e4d389680f6a84a386bf8550da3b0142..ebc18f13710f82ca2442b6c3fe882c4ff6560a19 100644
--- a/src/robofish/io/utils.py
+++ b/src/robofish/io/utils.py
@@ -48,7 +48,7 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None):
             files_path = []
             for ext in ("hdf", "hdf5", "h5", "he5"):
                 files_path += list(path.rglob(f"*.{ext}"))
-
+
             files.append(sorted(files_path)[:max_files])
         else:
             files.append([path])
@@ -123,3 +123,71 @@ def get_all_data_from_paths(
         all_data.append(data_from_files)
     pbar.close()
     return all_data, expected_settings
+
+
+def social_vectors(poses):
+    """
+    Args:
+        poses (np.ndarray): n_tracks, n_fish, n_timesteps, (x,y, ori_rad)
+    """
+
+    assert not np.any(np.isnan(poses)), "poses contains NaNs"
+
+    n_tracks, n_fish, n_timesteps, three = poses.shape
+    assert three == 3, f"poses.shape[-1] must be 3 (x,y,ori_rad). Got {three}"
+
+    sv = np.zeros((n_tracks, n_fish, n_timesteps, n_fish, 4))
+    for focal_id in range(n_fish):
+        for other_id in range(n_fish):
+            if focal_id == other_id:
+                continue
+
+            # Compute relative orientation
+            # sv[:, focal_id, :, other_id, 2] = (
+            #     poses[:, other_id, :, 2] - poses[:, focal_id, :, 2]
+            # )
+            relative_orientation = poses[:, other_id, :, 2] - poses[:, focal_id, :, 2]
+            sv[:, focal_id, :, other_id, 2:4] = np.stack(
+                [np.cos(relative_orientation), np.sin(relative_orientation)], axis=-1
+            )
+
+            # Raw social vectors
+            xy_offset = poses[:, other_id, :, :2] - poses[:, focal_id, :, :2]
+
+            # Rotate them, so they are from the POV of the focal fish
+
+            # to convert a vector from world coords to fish coords (when the fish has
+            # an orientation of alpha), you need to rotate the vector by pi/2 first
+            # and then rotate it back by alpha
+            rotate_by = np.pi / 2 - poses[:, focal_id, :, 2]
+
+            # perform rotation of raw social vectors
+            sin, cos = np.sin(rotate_by), np.cos(rotate_by)
+            sv[:, focal_id, :, other_id, 0] = (
+                xy_offset[:, :, 0] * cos - xy_offset[:, :, 1] * sin
+            )
+            sv[:, focal_id, :, other_id, 1] = (
+                xy_offset[:, :, 0] * sin + xy_offset[:, :, 1] * cos
+            )
+
+            assert not np.isnan(sv).any(), np.isnan(sv)
+
+            # assert length of social vectors was preserved
+            assert np.all(
+                np.isclose(
+                    np.linalg.norm(
+                        sv[:, focal_id, :, other_id, :2], axis=2
+                    ),
+                    np.linalg.norm(xy_offset, axis=2),
+                )
+            ), "The length of the social vector was not preserved."
+
+    return sv
+
+
+def social_vectors_without_focal_zeros(poses):
+    sv = social_vectors(poses)
+    mask = np.full_like(sv, fill_value=True, dtype=bool)
+    for focal_id in range(sv.shape[1]):
+        mask[:, focal_id, :, focal_id] = False
+    return sv[mask].reshape((sv.shape[0], sv.shape[1], sv.shape[2], sv.shape[3] - 1, 4))
diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py
index 4b1332f7dad72d268e1c6f59a571731f2ca06159..1d0ff26f6505e6dcc476811dc7e9fa7bc16ea8b8 100644
--- a/tests/robofish/evaluate/test_app_evaluate.py
+++ b/tests/robofish/evaluate/test_app_evaluate.py
@@ -22,6 +22,7 @@ class DummyArgs:
         self.save_path = save_path
         self.labels = None
         self.add_train_data = add_train_data
+        self.max_files = None
 
 
 def test_app_validate(tmp_path):
diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py
index 6c3a16f62e6f34708624236ec4d1525be4b39254..c9c7515ad9eb5db0ce972efad445fecfe6d5d418 100644
--- a/tests/robofish/io/test_app_io.py
+++ b/tests/robofish/io/test_app_io.py
@@ -22,8 +22,8 @@ def test_app_validate():
     with pytest.warns(UserWarning):
         raw_output = app.validate(DummyArgs([resources_path], "raw"))
 
-    # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found.
-    assert len(raw_output) == 4
+    # The four files valid.hdf5, almost_valid.hdf5, valid_couzin_params.hdf5, and invalid.hdf5 should be found.
+    assert len(raw_output) == 5
 
     # invalid.hdf5 should pass a warning
     with pytest.warns(UserWarning):
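
Editor's note: a minimal usage sketch of the two helpers this diff adds to `src/robofish/io/utils.py`, assuming the package is installed with this change applied; the random `poses` array and its dimensions are purely illustrative.

```python
import numpy as np

from robofish.io.utils import social_vectors, social_vectors_without_focal_zeros

# Illustrative poses: 2 tracks, 3 fish, 100 timesteps, components (x, y, ori_rad).
poses = np.random.rand(2, 3, 100, 3)

# Full tensor: one 4d social vector per (focal, other) pair and timestep,
# including the all-zero entries where focal == other.
sv = social_vectors(poses)
assert sv.shape == (2, 3, 100, 3, 4)

# Same data with each focal fish's own (zero) entry dropped.
sv_nf = social_vectors_without_focal_zeros(poses)
assert sv_nf.shape == (2, 3, 100, 2, 4)
```

The last axis holds the neighbour's offset rotated into the focal fish's frame plus the cosine and sine of the relative orientation, as computed in `social_vectors` above.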