diff --git a/.gitignore b/.gitignore index 0cd5109888fa7ddcebdb5a6386249d8b1aed7110..674857fcdf8c6b545bd3339ba3b7f81210cfd9ae 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,8 @@ env *.hdf5 *.mp4 +# Ignore all files with ignore_ prefix +ignore_* !tests/resources/*.hdf5 feature_requests.md diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1ffc452e989e72155ab4866bb2f7e2e42cbec40d..8649c2f8d11f98dc902d607cacfc9089b7e558ae 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -11,11 +11,11 @@ stages: .macos: tags: [macos, shell] -.windows: - tags: [windows, docker] - image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/windows:latest-devel - before_script: - - . $Profile.AllUsersAllHosts +#.windows: +# tags: [windows, docker] +# image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/windows:latest-devel +# before_script: +# - . $Profile.AllUsersAllHosts .python38: &python38 PYTHON_VERSION: "3.8" @@ -54,9 +54,9 @@ package: extends: .macos <<: *test -"test: [windows, 3.8]": - extends: .windows - <<: *test +#"test: [windows, 3.8]": +# extends: .windows +# <<: *test deploy to staging: extends: .centos diff --git a/src/conversion_scripts/load_model_params_from_model_options.py b/src/conversion_scripts/load_model_params_from_model_options.py new file mode 100644 index 0000000000000000000000000000000000000000..fb17140de699ae3a8b2c0419f8ee32da59dec804 --- /dev/null +++ b/src/conversion_scripts/load_model_params_from_model_options.py @@ -0,0 +1,44 @@ +from pathlib import Path +import argparse +import robofish.io +from tqdm import tqdm + +def main(file_path): + files = list(file_path.glob("*.hdf5")) if file_path.is_dir() else [file_path] + + for f in tqdm(files): + with robofish.io.File(f, "r+") as iof: + model_options = iof.attrs["model_options"] + model_options = eval(model_options)["options"] + + neccessary_options = ["zor", "zoo", "zoa", "fov", "additive_zone_sizes"] + + # extracted_options + e_options = {no: model_options[no] for no in neccessary_options} + + if e_options["additive_zone_sizes"]: + e_options["zoo"] += e_options["zor"] + e_options["zoa"] += e_options["zoo"] + + for e in iof.entities: + e.attrs["category"] = "organism" + + if "model" in e: + del e["model"] + + g = e.create_group("model") + g.attrs["name"] = "couzin" + + p = g.create_group("parameters") + p.attrs["zoo"] = e_options["zoo"] + p.attrs["zoa"] = e_options["zoa"] + p.attrs["zor"] = e_options["zor"] + p.attrs["fov"] = e_options["fov"] + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert csv file from the socoro experiments to RoboFish track format.") + parser.add_argument("file_path", type=str, help="Path to the csv file.") + args = parser.parse_args() + main(Path(args.file_path)) \ No newline at end of file diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index c7d0a443959a55ee1f724fd8dafa03898312562a..78697fd24c4374f8481c6408eaa1201955d6804c 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -94,20 +94,27 @@ def evaluate(args: dict = None) -> None: default=None, ) parser.add_argument( - "--save_path", + "--save-path", + "--save_path", # for backwards compatibility type=str, help="Filename for saving resulting graphics.", default=None, ) parser.add_argument( - "--add_train_data", + "--max-files", + "--max_files", # for backwards compatibility + type=int, + default=None, + help="The maximum number of files to be loaded from the given paths.", + ) + parser.add_argument( + "--add-train-data", + "--add_train_data", # for backwards compatibility action="store_true", help="Add the training data to the evaluation.", default=False, ) - # TODO: ignore fish/ consider_names - if args is None: args = parser.parse_args() @@ -117,6 +124,7 @@ def evaluate(args: dict = None) -> None: if args.analysis_type in fdict: paths = args.paths labels = args.labels + max_files = args.max_files print("starting", paths, labels) if labels is None: @@ -141,7 +149,7 @@ def evaluate(args: dict = None) -> None: print("starting", paths, labels) save_path = None if args.save_path is None else Path(args.save_path) - params = {"paths": paths, "labels": labels} + params = {"paths": paths, "labels": labels, "max_files": max_files} if args.analysis_type == "all": normal_functions = function_dict() normal_functions.pop("all") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index bede8f7ada5fbabbed8b39940bc94e3a26355097..793eb31f9a43881de231c5579c3eec493aa6d0a9 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -31,6 +31,7 @@ def evaluate_speed( labels: Iterable[str] = None, predicate: Callable[[robofish.io.Entity], bool] = None, speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the speed of the entities as histogram. @@ -47,7 +48,7 @@ def evaluate_speed( if speeds_turns_from_paths is None: speeds_turns_from_paths, _ = utils.get_all_data_from_paths( - paths, "speeds_turns", predicate=predicate + paths, "speeds_turns", predicate=predicate, max_files=max_files ) speeds = [] @@ -92,6 +93,7 @@ def evaluate_turn( predicate: Callable[[robofish.io.Entity], bool] = None, speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the turn angles of the entities as histogram. @@ -109,7 +111,7 @@ def evaluate_turn( if speeds_turns_from_paths is None: speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( - paths, "speeds_turns", predicate=predicate + paths, "speeds_turns", predicate=predicate, max_files=max_files ) turns = [] @@ -156,6 +158,7 @@ def evaluate_orientation( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the orientations of the entities on a 2d grid. @@ -172,7 +175,7 @@ def evaluate_orientation( """ if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate=predicate + paths, predicate=predicate, max_files=max_files ) world_bounds = [ @@ -243,6 +246,7 @@ def evaluate_relative_orientation( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the relative orientations of the entities as a histogram. @@ -260,7 +264,7 @@ def evaluate_relative_orientation( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) orientations = [] @@ -299,6 +303,7 @@ def evaluate_distance_to_wall( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of the entities to the walls as a histogram. Lambda function example: lambda e: e.category == "fish" @@ -315,7 +320,7 @@ def evaluate_distance_to_wall( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) world_bounds = [ @@ -387,6 +392,7 @@ def evaluate_tank_position( poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, max_points: int = 4000, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the positions of the entities as a heatmap. Lambda function example: lambda e: e.category == "fish" @@ -404,7 +410,7 @@ def evaluate_tank_position( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) xy_positions = [] @@ -479,18 +485,13 @@ def evaluate_quiver( """ try: import torch - from fish_models.models.pascals_lstms.attribution import SocialVectors except ImportError: - print( - "Either torch or fish_models is not installed.", - "The social vector should come from robofish io.", - "This is a known issue (#24)", - ) + print("Torch is not installed and could not be imported") return if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate, predicate + paths, predicate, max_files=max_files ) if speeds_turns_from_paths is None: speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( @@ -505,11 +506,6 @@ def evaluate_quiver( poses_from_paths = np.array(poses_from_paths) - # if speeds_turns_from_paths.dtype == np.object: - # print( - # "This will probably fail because the type of speeds_turns_from_path is object" - # ) - # print(speeds_turns_from_paths) try: print(speeds_turns_from_paths) speeds_turns_from_paths = np.stack( @@ -552,10 +548,11 @@ def evaluate_quiver( # print(tank_directions[x,y] - tank_directions_speed[x,y]) tank_count[x, y] = len(d) - sv = SocialVectors(poses_from_paths[0]) - sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape( - (-1, 3) - ) # [:1000] + sv_r = torch.tensor( + robofish.io.utils.social_vectors_without_focal_zeros( + poses_from_paths[0][:, :, :-1] + ).reshape((-1, 4)) + ) sv_r = torch.cat( (sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1 ) @@ -636,6 +633,7 @@ def evaluate_social_vector( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap. Lambda function example: lambda e: e.category == "fish" @@ -650,15 +648,9 @@ def evaluate_social_vector( matplotlib.figure.Figure: The figure of the social vectors. """ - try: - from fish_models.models.pascals_lstms.attribution import SocialVectors - except ImportError: - warnings.warn("Please install the fish_models package to use this function.") - return - if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) fig, ax = plt.subplots( @@ -678,8 +670,8 @@ def evaluate_social_vector( axis=-1, ) - social_vec = SocialVectors(poses).social_vectors_without_focal_zeros - flat_sv = social_vec.reshape((-1, 3)) + social_vec = robofish.io.utils.social_vectors_without_focal_zeros(poses) + flat_sv = social_vec.reshape((-1, 4)) bins = ( 30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100 @@ -703,6 +695,7 @@ def evaluate_follow_iid( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the follow metric in respect to the inter individual distance (iid). Lambda function example: lambda e: e.category == "fish" @@ -719,7 +712,7 @@ def evaluate_follow_iid( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) follow, iid = [], [] @@ -769,7 +762,7 @@ def evaluate_follow_iid( follow_iid_data = pd.DataFrame( { - "IID [cm]": iid[i][mask], + f"IID [cm] {labels[i]}": iid[i][mask], "Follow": follow[i][mask], }, dtype=np.float32, @@ -777,7 +770,7 @@ def evaluate_follow_iid( plt.rcParams["lines.markersize"] = 1 grid = sns.jointplot( - x="IID [cm]", + x=f"IID [cm] {labels[i]}", y="Follow", data=follow_iid_data, kind="hist", @@ -788,20 +781,13 @@ def evaluate_follow_iid( joint_kws={"bins": 30}, marginal_kws=dict(bins=30), ) - # grid.fig.set_figwidth(9) - # grid.fig.set_figheight(6) - # grid.fig.subplots_adjust(top=0.9) grids.append(grid) # This is neccessary because joint plot does not receive an ax object. # It creates issues with too many plots though # Created an issue - fig = plt.figure(figsize=(6 * len(grids), 6)) + fig = plt.figure(figsize=(12, 5)) - fig.suptitle( - f"follow/iid: from left to right:\n{', '.join([str(l) for l in labels])}", - fontsize=12, - ) gs = gridspec.GridSpec(1, len(grids)) for i in range(len(grids)): @@ -817,6 +803,7 @@ def evaluate_tracks_distance( poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, max_timesteps: int = 4000, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -837,6 +824,7 @@ def evaluate_tracks_distance( predicate, lw_distances=True, max_timesteps=max_timesteps, + max_files=max_files, ) @@ -850,6 +838,7 @@ def evaluate_tracks( seed: int = 42, max_timesteps: int = None, verbose: bool = False, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -871,9 +860,11 @@ def evaluate_tracks( paths = [Path(p) for p in paths] random.seed(seed) - files_per_path = utils.get_all_files_from_paths(paths) + files_per_path = utils.get_all_files_from_paths(paths, max_files=max_files) max_files_per_path = max([len(files) for files in files_per_path]) + min_files_per_path = min([len(files) for files in files_per_path]) + rows, cols = len(files_per_path), min(6, max_files_per_path) multirow = False @@ -886,12 +877,13 @@ def evaluate_tracks( initial_poses_info_available = None all_initial_info = [] + selected_tracks = np.random.choice(range(min_files_per_path), cols, replace=False) + # Iterate all paths for k, files_in_path in enumerate(files_per_path): - random.shuffle(files_in_path) # Iterate all files - for i, file_path in enumerate(files_in_path): + for i, file_path in enumerate(np.array(files_in_path)[selected_tracks]): if multirow: if i >= cols * rows: break @@ -1322,26 +1314,26 @@ def calculate_distLinePoint( def show_values( pc: matplotlib.collections.PolyCollection, fmt: str = "%.2f", **kw: dict ) -> None: - """Show numbers on plt.ax.pccolormesh plot. - - https://stackoverflow.com/questions/25071968/ - heatmap-with-text-in-each-cell-with-matplotlibs-pyplot + """Show numbers on plt.ax.pcolormesh plot. Args: pc(matplotlib.collections.PolyCollection): The plot to show the values on. fmt(str): The format of the values. kw(dict): The keyword arguments. """ + values = pc.get_array() pc.update_scalarmappable() ax = pc.axes - for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()): + + for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values.ravel()): x, y = p.vertices[:-2, :].mean(0) - if np.all(color[:3] > 0.5): - color = (0.0, 0.0, 0.0) - else: - color = (1.0, 1.0, 1.0) - if value == "--": - ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw) + + # Choose text color based on background + text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0) + + # Only show non-masked values + if not np.ma.is_masked(value): + ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw) class SeabornFig2Grid: diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 9c352111019ec9606694fcff65299ecd86d00ba5..f0f03a32a9def768814158e0fb050ff1de586119 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -41,14 +41,16 @@ def print_file(args: argparse.Namespace = None) -> bool: parser.add_argument("path", type=str, help="The path to a hdf5 file") parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, choices=["shape", "full"], default="shape", help="Choose how datasets are printed, either the shapes or the full content is printed", ) parser.add_argument( - "--full_attrs", + "--full-attrs", + "--full_attrs", # backwards compatibility default=False, action="store_true", help="Show full unabbreviated values for attributes", @@ -122,7 +124,8 @@ def validate(args: argparse.Namespace = None) -> int: description="The function can be directly accessed from the commandline and can be given any number of files or folders. The function returns the validity of the files in a human readable format or as a raw output." ) parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, default="h", choices=["h", "raw"], @@ -179,7 +182,9 @@ def render_file(kwargs: Dict) -> None: kwargs (Dict, optional): A dictionary containing the arguments for the render function. """ with robofish.io.File(path=kwargs["path"]) as f: - f.render(**{k:v for k,v in kwargs.items() if k not in ("path", "reference_track")}) + f.render( + **{k: v for k, v in kwargs.items() if k not in ("path", "reference_track")} + ) def overwrite_user_configs() -> None: @@ -248,7 +253,7 @@ def render(args: argparse.Namespace = None) -> None: default=[], help="Custom colors to use for guppies. Use spaces as delimiter. " "To set all guppies to the same color, pass only one color. " - "Hexadecimal values, color names and matplotlib abbreviations are supported (\"#000000\", black, k)" + 'Hexadecimal values, color names and matplotlib abbreviations are supported ("#000000", black, k)', ) default_options = { @@ -267,6 +272,7 @@ def render(args: argparse.Namespace = None) -> None: "render_swarm_center": False, "highlight_switches": False, "figsize": 10, + "fov_smoothing_factor": 0.8, } for key, value in default_options.items(): diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index c2e7a67d128039d075d4ece787ac59f2a900e3f8..7ff5b340a33d63d8d8b6430378702cd0020b0d53 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -236,7 +236,7 @@ class Entity(h5py.Group): "merge to the master branch of fish_models if nothing helps, contact Andi.\n" "Don't ignore this warning, it's a serious issue.", ) - def speed_turn(self): + def speed_turn(self) -> np.ndarray: """Get the speed, turn and from the positions. The vectors pointing from each position to the next are computed. diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index a1b82e001e2c179f1961d54404d3152619304fd1..0afc5cdef44b33f972c6bc59e86b4c4f89c2f708 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -61,7 +61,7 @@ class File(h5py.File): def __init__( self, - path: Optinal[Union[str, Path]] = None, + path: Optional[Union[str, Path]] = None, mode: str = "r", *, # PEP 3102 world_size_cm: Optional[List[int]] = None, @@ -150,6 +150,8 @@ class File(h5py.File): self.validate_when_saving = validate_when_saving self.calculate_data_on_close = calculate_data_on_close + self._entities = None + if open_copy: assert ( path is not None @@ -502,6 +504,8 @@ class File(h5py.File): sampling=sampling, ) + self._entities = None # Reset the entities cache + return entity def create_multiple_entities( @@ -584,10 +588,12 @@ class File(h5py.File): @property def entities(self): - return [ - robofish.io.Entity.from_h5py_group(self["entities"][name]) - for name in self.entity_names - ] + if self._entities is None: + self._entities = [ + robofish.io.Entity.from_h5py_group(self["entities"][name]) + for name in self.entity_names + ] + return self._entities @property def entity_positions(self): @@ -702,7 +708,11 @@ class File(h5py.File): else: properties = [entity_property.__get__(entity) for entity in entities] - max_timesteps = max([0] + [p.shape[0] for p in properties]) + n_timesteps = [p.shape[0] for p in properties] + max_timesteps = max(n_timesteps) + + if np.all(np.equal(n_timesteps, max_timesteps)): + return np.array(properties) property_array = np.empty( (len(entities), max_timesteps, properties[0].shape[1]) @@ -926,7 +936,7 @@ class File(h5py.File): else: step_size = poses.shape[1] - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) x_world, y_world = self.world_size if figsize is None: @@ -971,6 +981,15 @@ class File(h5py.File): label="End", zorder=5, ) + ax.scatter( + [poses[:, 0, 0]], + [poses[:, 0, 1]], + marker="o", + c="black", + s=ms, + label="Start", + zorder=5, + ) if legend and isinstance(legend, str): ax.legend(legend) elif legend: @@ -1007,6 +1026,7 @@ class File(h5py.File): custom_colors: bool = None, dpi: int = 200, figsize: int = 10, + fov_smoothing_factor: float = 0.8, ) -> None: """Render a video of the file. @@ -1059,13 +1079,14 @@ class File(h5py.File): categories = [entity.attrs.get("category", None) for entity in self.entities] n_fish = len([c for c in categories if c == "organism"]) - lines = [ plt.plot( [], [], lw=linewidth, - color=custom_colors[i%len(custom_colors)-1] if custom_colors else None, + color=custom_colors[i % len(custom_colors) - 1] + if custom_colors + else None, zorder=0, )[0] for i in range(n_entities) @@ -1077,23 +1098,54 @@ class File(h5py.File): if categories[i] == "organism" ] - zone_sizes = [ - ( - get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) - if render_zones - else {} - ) - for _ in range(n_fish) - ] - zones = [ - [ - plt.Circle( - (0, 0), zone_size, color=fish_colors[i], alpha=0.2, fill=False + zones = [] + if render_zones: + fovs = [] + for ei, e in enumerate(self.entities): + zone_sizes_str = get_zone_sizes_from_model_str( + self.attrs.get("guppy_model_rollout", "") ) - for zone_size in zone_sizes_fish.values() - ] - for i, zone_sizes_fish in enumerate(zone_sizes) - ] + zone_sizes_attrs = get_zone_sizes_from_attrs(e) + + # Check that there are no zone sizes in the model and in the attributes + assert ( + zone_sizes_str == {} or zone_sizes_attrs == {} + ), "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." + zone_sizes = ( + zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str + ) + + fov = zone_sizes.get("fov", np.pi * 2) + fov = np.rad2deg(fov) + zone_sizes.pop("fov", None) + fovs.append(fov) + + entity_zones = [] + for zone_size in zone_sizes.values(): + if fov >= 360: + entity_zones.append( + matplotlib.patches.Circle( + (0, 0), + zone_size, + color=fish_colors[ei], + alpha=0.2, + fill=False, + ) + ) + else: + entity_zones.append( + matplotlib.patches.Wedge( + (0, 0), + zone_size, + theta1=-fov / 2, + theta2=fov / 2, + color=fish_colors[ei], + alpha=0.2, + fill=False, + ) + ) + zones.append(entity_zones) + zones_flat = [] for zones_fish in zones: for zone in zones_fish: @@ -1292,6 +1344,8 @@ class File(h5py.File): pbar = tqdm(range(n_frames)) if video_path is not None else None + self.fov_orientations = [None] * n_entities + def update(frame): output_list = [] @@ -1368,12 +1422,31 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): - if categories[i_entity] == "organism": - for zone in zones[i_entity]: - zone.center = ( + new_ori = this_pose[i_entity, 2] + old_ori = ( + self.fov_orientations[i_entity] + if self.fov_orientations[i_entity] is not None + else new_ori + ) + + # Mix in complex space to handle the wrap around + smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + ( + 1 - fov_smoothing_factor + ) * np.exp(1j * new_ori) + self.fov_orientations[i_entity] = np.angle(smoothed_ori) + + ori_deg = np.rad2deg(self.fov_orientations[i_entity]) + + for zone in zones[i_entity]: + zone.set_center( + ( this_pose[i_entity, 0], this_pose[i_entity, 1], ) + ) + if fovs[i_entity] < 360: + zone.theta1 = ori_deg - fovs[i_entity] / 2 + zone.theta2 = ori_deg + fovs[i_entity] / 2 if render_swarm_center: swarm_center[0].set_offsets(swarm_center_position[file_frame]) @@ -1434,15 +1507,18 @@ class File(h5py.File): plt.show() -def get_zone_sizes(model: str): +def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]: + if "model" not in e: + return {} + else: + possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"] + model_params = e["model"]["parameters"].attrs + return {p: model_params[p] for p in possible_params if p in model_params} + + +def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]: zone_sizes = {} - for zone in [ - "zor", - "zoa", - "zoo", - "preferred_d", - "neighbor_radius", - ]: + for zone in ["zor", "zoa", "zoo", "preferred_d", "neighbor_radius", "fov"]: match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) if match: value = float(match.group(1)) diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 1a646662e77d21e42734d6e8cce1d8bf8a416f7c..ebc18f13710f82ca2442b6c3fe882c4ff6560a19 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -48,7 +48,8 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None): files_path = [] for ext in ("hdf", "hdf5", "h5", "he5"): files_path += list(path.rglob(f"*.{ext}")) - files.append(files_path[:max_files]) + + files.append(sorted(files_path)[:max_files]) else: files.append([path]) return files @@ -122,3 +123,71 @@ def get_all_data_from_paths( all_data.append(data_from_files) pbar.close() return all_data, expected_settings + + +def social_vectors(poses): + """ + Args: + poses (np.ndarray): n_tracks, n_fish, n_timesteps, (x,y, ori_rad) + """ + + assert not np.any(np.isnan(poses)), "poses contains NaNs" + + n_tracks, n_fish, n_timesteps, three = poses.shape + assert three == 3, "poses.shape[-1] must be 3 (x,y,ori_rad). Got {three}" + + sv = np.zeros((n_tracks, n_fish, n_timesteps, n_fish, 4)) + for focal_id in range(n_fish): + for other_id in range(n_fish): + if focal_id == other_id: + continue + + # Compute relative orienation + # sv[:, focal_id, :, other_id, 2] = ( + # poses[:, other_id, :, 2] - poses[:, focal_id, :, 2] + # ) + relative_orientation = poses[:, other_id, :, 2] - poses[:, focal_id, :, 2] + sv[:, focal_id, :, other_id, 2:4] = np.stack( + [np.cos(relative_orientation), np.sin(relative_orientation)], axis=-1 + ) + + # Raw social vectors + xy_offset = poses[:, other_id, :, :2] - poses[:, focal_id, :, :2] + + # Rotate them, so they are from the POV of the focal fish + + # to convert a vector from world coords to fish coords (when the fish has + # an orientation of alpha), you need to rotate the vector by pi/2 first + # and then rotate it back by alpha + rotate_by = np.pi / 2 - poses[:, focal_id, :, 2] + + # perform rotation of raw social vectors + sin, cos = np.sin(rotate_by), np.cos(rotate_by) + sv[:, focal_id, :, other_id, 0] = ( + xy_offset[:, :, 0] * cos - xy_offset[:, :, 1] * sin + ) + sv[:, focal_id, :, other_id, 1] = ( + xy_offset[:, :, 0] * sin + xy_offset[:, :, 1] * cos + ) + + assert not np.isnan(sv).any(), np.isnan(sv) + + # assert length of social vectors was preserved + assert True or np.all( + np.isclose( + np.linalg.norm( + social_vectors[:, focal_id, :, other_id, :2], axis=2 + ), + np.linalg.norm(xy_offset, axis=2), + ) + ), "The length of the social vector was not preserved." + + return sv + + +def social_vectors_without_focal_zeros(poses): + sv = social_vectors(poses) + mask = np.full_like(sv, fill_value=True, dtype=bool) + for focal_id in range(sv.shape[1]): + mask[:, focal_id, :, focal_id] = False + return sv[mask].reshape((sv.shape[0], sv.shape[1], sv.shape[2], sv.shape[3] - 1, 4)) diff --git a/tests/resources/valid_couzin_params.hdf5 b/tests/resources/valid_couzin_params.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..506cdd1ae9df5a0438b2bcec9b47ffc016037515 Binary files /dev/null and b/tests/resources/valid_couzin_params.hdf5 differ diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 4b1332f7dad72d268e1c6f59a571731f2ca06159..1d0ff26f6505e6dcc476811dc7e9fa7bc16ea8b8 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,6 +22,7 @@ class DummyArgs: self.save_path = save_path self.labels = None self.add_train_data = add_train_data + self.max_files = None def test_app_validate(tmp_path): diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index 6c3a16f62e6f34708624236ec4d1525be4b39254..c9c7515ad9eb5db0ce972efad445fecfe6d5d418 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -22,8 +22,8 @@ def test_app_validate(): with pytest.warns(UserWarning): raw_output = app.validate(DummyArgs([resources_path], "raw")) - # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. - assert len(raw_output) == 4 + # The three files valid.hdf5, almost_valid.hdf5, valid_couzin_params.hdf5, and invalid.hdf5 should be found. + assert len(raw_output) == 5 # invalid.hdf5 should pass a warning with pytest.warns(UserWarning):