From 26710c205d04ed26de91f6838e10e406f9d5d04d Mon Sep 17 00:00:00 2001
From: Andi <andi.gerken@gmail.com>
Date: Mon, 24 Mar 2025 14:56:49 +0100
Subject: [PATCH] Moved social vectors to robofish.io.utils to finally solve
 issue #24

---
 .gitignore                                   |  2 +
 src/robofish/evaluate/app.py                 | 11 +--
 src/robofish/evaluate/evaluate.py            | 91 ++++++++++----------
 src/robofish/io/app.py                       | 17 ++--
 src/robofish/io/file.py                      | 87 ++++++++++++-------
 src/robofish/io/utils.py                     | 70 ++++++++++++++-
 tests/robofish/evaluate/test_app_evaluate.py |  1 +
 tests/robofish/io/test_app_io.py             |  4 +-
 8 files changed, 191 insertions(+), 92 deletions(-)

diff --git a/.gitignore b/.gitignore
index 0cd5109..674857f 100644
--- a/.gitignore
+++ b/.gitignore
@@ -17,6 +17,8 @@ env
 *.hdf5
 *.mp4
 
+# Ignore all files with ignore_ prefix
+ignore_*
 !tests/resources/*.hdf5
 
 feature_requests.md
diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index 5c837bd..78697fd 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -94,26 +94,27 @@ def evaluate(args: dict = None) -> None:
         default=None,
     )
     parser.add_argument(
-        "--save_path",
+        "--save-path",
+        "--save_path",  # for backwards compatibility
         type=str,
         help="Filename for saving resulting graphics.",
         default=None,
     )
     parser.add_argument(
-        "--max_files",
+        "--max-files",
+        "--max_files",  # for backwards compatibility
         type=int,
         default=None,
         help="The maximum number of files to be loaded from the given paths.",
     )
     parser.add_argument(
-        "--add_train_data",
+        "--add-train-data",
+        "--add_train_data",  # for backwards compatibility
         action="store_true",
         help="Add the training data to the evaluation.",
         default=False,
     )
 
-    # TODO: ignore fish/ consider_names
-
     if args is None:
         args = parser.parse_args()
 
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 1f4d2af..793eb31 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -31,6 +31,7 @@ def evaluate_speed(
     labels: Iterable[str] = None,
     predicate: Callable[[robofish.io.Entity], bool] = None,
     speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the speed of the entities as histogram.
 
@@ -47,7 +48,7 @@ def evaluate_speed(
 
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, _ = utils.get_all_data_from_paths(
-            paths, "speeds_turns", predicate=predicate
+            paths, "speeds_turns", predicate=predicate, max_files=max_files
         )
 
     speeds = []
@@ -92,6 +93,7 @@ def evaluate_turn(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the turn angles of the entities as histogram.
 
@@ -109,7 +111,7 @@ def evaluate_turn(
 
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
-            paths, "speeds_turns", predicate=predicate
+            paths, "speeds_turns", predicate=predicate, max_files=max_files
         )
 
     turns = []
@@ -156,6 +158,7 @@ def evaluate_orientation(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the orientations of the entities on a 2d grid.
 
@@ -172,7 +175,7 @@ def evaluate_orientation(
     """
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate=predicate
+            paths, predicate=predicate, max_files=max_files
         )
 
     world_bounds = [
@@ -243,6 +246,7 @@ def evaluate_relative_orientation(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the relative orientations of the entities as a histogram.
 
@@ -260,7 +264,7 @@ def evaluate_relative_orientation(
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     orientations = []
@@ -299,6 +303,7 @@ def evaluate_distance_to_wall(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the distances of the entities to the walls as a histogram.
     Lambda function example: lambda e: e.category == "fish"
@@ -315,7 +320,7 @@ def evaluate_distance_to_wall(
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     world_bounds = [
@@ -387,6 +392,7 @@ def evaluate_tank_position(
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
     max_points: int = 4000,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the positions of the entities as a heatmap.
     Lambda function example: lambda e: e.category == "fish"
@@ -404,7 +410,7 @@ def evaluate_tank_position(
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     xy_positions = []
@@ -479,18 +485,13 @@ def evaluate_quiver(
     """
     try:
         import torch
-        from fish_models.models.pascals_lstms.attribution import SocialVectors
     except ImportError:
-        print(
-            "Either torch or fish_models is not installed.",
-            "The social vector should come from robofish io.",
-            "This is a known issue (#24)",
-        )
+        print("Torch is not installed and could not be imported")
         return
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate, predicate
+            paths, predicate, max_files=max_files
         )
     if speeds_turns_from_paths is None:
         speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
@@ -547,10 +548,11 @@ def evaluate_quiver(
                 # print(tank_directions[x,y] - tank_directions_speed[x,y])
                 tank_count[x, y] = len(d)
 
-    sv = SocialVectors(poses_from_paths[0])
-    sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape(
-        (-1, 3)
-    )  # [:1000]
+    sv_r = torch.tensor(
+        robofish.io.utils.social_vectors_without_focal_zeros(
+            poses_from_paths[0][:, :, :-1]
+        ).reshape((-1, 4))
+    )
     sv_r = torch.cat(
         (sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1
     )
@@ -631,6 +633,7 @@ def evaluate_social_vector(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap.
     Lambda function example: lambda e: e.category == "fish"
@@ -645,17 +648,9 @@ def evaluate_social_vector(
         matplotlib.figure.Figure: The figure of the social vectors.
     """
 
-    try:
-        from fish_models.models.pascals_lstms.attribution import SocialVectors
-    except ImportError as e:
-        warnings.warn(
-            "Please install the fish_models package to use this function.\n", e
-        )
-        return
-
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     fig, ax = plt.subplots(
@@ -675,8 +670,8 @@ def evaluate_social_vector(
             axis=-1,
         )
 
-        social_vec = SocialVectors(poses).social_vectors_without_focal_zeros
-        flat_sv = social_vec.reshape((-1, 3))
+        social_vec = robofish.io.utils.social_vectors_without_focal_zeros(poses)
+        flat_sv = social_vec.reshape((-1, 4))
 
         bins = (
             30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100
@@ -700,6 +695,7 @@ def evaluate_follow_iid(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the follow metric in respect to the inter individual distance (iid).
     Lambda function example: lambda e: e.category == "fish"
@@ -716,7 +712,7 @@ def evaluate_follow_iid(
 
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
-            paths, predicate
+            paths, predicate, max_files=max_files
         )
 
     follow, iid = [], []
@@ -792,12 +788,11 @@ def evaluate_follow_iid(
     # Created an issue
     fig = plt.figure(figsize=(12, 5))
 
-    
     gs = gridspec.GridSpec(1, len(grids))
-    
-    for i in range(len(grids)):    
-        SeabornFig2Grid(grids[i], fig, gs[i])    
-    
+
+    for i in range(len(grids)):
+        SeabornFig2Grid(grids[i], fig, gs[i])
+
     return fig
 
 
@@ -808,6 +803,7 @@ def evaluate_tracks_distance(
     poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
     file_settings: dict = None,
     max_timesteps: int = 4000,
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the distances of two or more fish on the track.
     Lambda function example: lambda e: e.category == "fish"
@@ -828,6 +824,7 @@ def evaluate_tracks_distance(
         predicate,
         lw_distances=True,
         max_timesteps=max_timesteps,
+        max_files=max_files,
     )
 
 
@@ -841,7 +838,7 @@ def evaluate_tracks(
     seed: int = 42,
     max_timesteps: int = None,
     verbose: bool = False,
-    max_files: int = None
+    max_files: int = None,
 ) -> matplotlib.figure.Figure:
     """Evaluate the distances of two or more fish on the track.
     Lambda function example: lambda e: e.category == "fish"
@@ -867,7 +864,7 @@ def evaluate_tracks(
 
     max_files_per_path = max([len(files) for files in files_per_path])
     min_files_per_path = min([len(files) for files in files_per_path])
-    
+
     rows, cols = len(files_per_path), min(6, max_files_per_path)
 
     multirow = False
@@ -1317,26 +1314,26 @@ def calculate_distLinePoint(
 def show_values(
     pc: matplotlib.collections.PolyCollection, fmt: str = "%.2f", **kw: dict
 ) -> None:
-    """Show numbers on plt.ax.pccolormesh plot.
-
-    https://stackoverflow.com/questions/25071968/
-    heatmap-with-text-in-each-cell-with-matplotlibs-pyplot
+    """Show numbers on plt.ax.pcolormesh plot.
 
     Args:
         pc(matplotlib.collections.PolyCollection): The plot to show the values on.
         fmt(str): The format of the values.
         kw(dict): The keyword arguments.
     """
+    values = pc.get_array()
     pc.update_scalarmappable()
     ax = pc.axes
-    for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()):
+
+    for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values.ravel()):
         x, y = p.vertices[:-2, :].mean(0)
-        if np.all(color[:3] > 0.5):
-            color = (0.0, 0.0, 0.0)
-        else:
-            color = (1.0, 1.0, 1.0)
-        if np.all(value == "--"):
-            ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw)
+
+        # Choose text color based on background
+        text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0)
+
+        # Only show non-masked values
+        if not np.ma.is_masked(value):
+            ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw)
 
 
 class SeabornFig2Grid:
diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index 4aec636..f0f03a3 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -41,14 +41,16 @@ def print_file(args: argparse.Namespace = None) -> bool:
 
     parser.add_argument("path", type=str, help="The path to a hdf5 file")
     parser.add_argument(
-        "--output_format",
+        "--output-format",
+        "--output_format",  # backwards compatibility
         type=str,
         choices=["shape", "full"],
         default="shape",
         help="Choose how datasets are printed, either the shapes or the full content is printed",
     )
     parser.add_argument(
-        "--full_attrs",
+        "--full-attrs",
+        "--full_attrs",  # backwards compatibility
         default=False,
         action="store_true",
         help="Show full unabbreviated values for attributes",
@@ -122,7 +124,8 @@ def validate(args: argparse.Namespace = None) -> int:
         description="The function can be directly accessed from the commandline and can be given any number of files or folders. The function returns the validity of the files in a human readable format or as a raw output."
     )
     parser.add_argument(
-        "--output_format",
+        "--output-format",
+        "--output_format",  # backwards compatibility
         type=str,
         default="h",
         choices=["h", "raw"],
@@ -179,7 +182,9 @@ def render_file(kwargs: Dict) -> None:
         kwargs (Dict, optional): A dictionary containing the arguments for the render function.
     """
     with robofish.io.File(path=kwargs["path"]) as f:
-        f.render(**{k:v for k,v in kwargs.items() if k not in ("path", "reference_track")})
+        f.render(
+            **{k: v for k, v in kwargs.items() if k not in ("path", "reference_track")}
+        )
 
 
 def overwrite_user_configs() -> None:
@@ -248,7 +253,7 @@ def render(args: argparse.Namespace = None) -> None:
         default=[],
         help="Custom colors to use for guppies. Use spaces as delimiter. "
         "To set all guppies to the same color, pass only one color. "
-        "Hexadecimal values, color names and matplotlib abbreviations are supported (\"#000000\", black, k)"
+        'Hexadecimal values, color names and matplotlib abbreviations are supported ("#000000", black, k)',
     )
 
     default_options = {
@@ -267,7 +272,7 @@ def render(args: argparse.Namespace = None) -> None:
         "render_swarm_center": False,
         "highlight_switches": False,
         "figsize": 10,
-        "fov_smoothing_factor": 0.8
+        "fov_smoothing_factor": 0.8,
     }
 
     for key, value in default_options.items():
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 8af3217..0afc5cd 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -936,7 +936,7 @@ class File(h5py.File):
         else:
             step_size = poses.shape[1]
 
-        cmap = matplotlib.cm.get_cmap(cmap)
+        cmap = plt.get_cmap(cmap)
 
         x_world, y_world = self.world_size
         if figsize is None:
@@ -1026,7 +1026,7 @@ class File(h5py.File):
         custom_colors: bool = None,
         dpi: int = 200,
         figsize: int = 10,
-        fov_smoothing_factor: float = 0.8
+        fov_smoothing_factor: float = 0.8,
     ) -> None:
         """Render a video of the file.
 
@@ -1102,26 +1102,50 @@ class File(h5py.File):
         if render_zones:
             fovs = []
             for ei, e in enumerate(self.entities):
-                zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", ""))
+                zone_sizes_str = get_zone_sizes_from_model_str(
+                    self.attrs.get("guppy_model_rollout", "")
+                )
                 zone_sizes_attrs = get_zone_sizes_from_attrs(e)
-                
+
                 # Check that there are no zone sizes in the model and in the attributes
-                assert zone_sizes_str == {} or zone_sizes_attrs == {}, "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)."
-                zone_sizes = zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str
-                
-                fov = zone_sizes.get("fov", np.pi*2)
+                assert (
+                    zone_sizes_str == {} or zone_sizes_attrs == {}
+                ), "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)."
+                zone_sizes = (
+                    zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str
+                )
+
+                fov = zone_sizes.get("fov", np.pi * 2)
                 fov = np.rad2deg(fov)
                 zone_sizes.pop("fov", None)
                 fovs.append(fov)
-                
+
                 entity_zones = []
                 for zone_size in zone_sizes.values():
                     if fov >= 360:
-                        entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.2, fill=False))
+                        entity_zones.append(
+                            matplotlib.patches.Circle(
+                                (0, 0),
+                                zone_size,
+                                color=fish_colors[ei],
+                                alpha=0.2,
+                                fill=False,
+                            )
+                        )
                     else:
-                        entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.2, fill=False))
+                        entity_zones.append(
+                            matplotlib.patches.Wedge(
+                                (0, 0),
+                                zone_size,
+                                theta1=-fov / 2,
+                                theta2=fov / 2,
+                                color=fish_colors[ei],
+                                alpha=0.2,
+                                fill=False,
+                            )
+                        )
                 zones.append(entity_zones)
-                
+
         zones_flat = []
         for zones_fish in zones:
             for zone in zones_fish:
@@ -1321,7 +1345,7 @@ class File(h5py.File):
         pbar = tqdm(range(n_frames)) if video_path is not None else None
 
         self.fov_orientations = [None] * n_entities
-        
+
         def update(frame):
             output_list = []
 
@@ -1399,19 +1423,27 @@ class File(h5py.File):
                 poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame]
                 for i_entity in range(n_entities):
                     new_ori = this_pose[i_entity, 2]
-                    old_ori = self.fov_orientations[i_entity] if self.fov_orientations[i_entity] is not None else new_ori
+                    old_ori = (
+                        self.fov_orientations[i_entity]
+                        if self.fov_orientations[i_entity] is not None
+                        else new_ori
+                    )
 
                     # Mix in complex space to handle the wrap around
-                    smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (1-fov_smoothing_factor) * np.exp(1j * new_ori)
+                    smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (
+                        1 - fov_smoothing_factor
+                    ) * np.exp(1j * new_ori)
                     self.fov_orientations[i_entity] = np.angle(smoothed_ori)
-                    
+
                     ori_deg = np.rad2deg(self.fov_orientations[i_entity])
-                    
-                    for zone in zones[i_entity]:    
-                        zone.set_center((
-                            this_pose[i_entity, 0],
-                            this_pose[i_entity, 1],
-                        ))
+
+                    for zone in zones[i_entity]:
+                        zone.set_center(
+                            (
+                                this_pose[i_entity, 0],
+                                this_pose[i_entity, 1],
+                            )
+                        )
                         if fovs[i_entity] < 360:
                             zone.theta1 = ori_deg - fovs[i_entity] / 2
                             zone.theta2 = ori_deg + fovs[i_entity] / 2
@@ -1482,18 +1514,11 @@ def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]:
         possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"]
         model_params = e["model"]["parameters"].attrs
         return {p: model_params[p] for p in possible_params if p in model_params}
-        
+
 
 def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]:
     zone_sizes = {}
-    for zone in [
-        "zor",
-        "zoa",
-        "zoo",
-        "preferred_d",
-        "neighbor_radius",
-        "fov"
-    ]:
+    for zone in ["zor", "zoa", "zoo", "preferred_d", "neighbor_radius", "fov"]:
         match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model)
         if match:
             value = float(match.group(1))
diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py
index 940879b..ebc18f1 100644
--- a/src/robofish/io/utils.py
+++ b/src/robofish/io/utils.py
@@ -48,7 +48,7 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None):
             files_path = []
             for ext in ("hdf", "hdf5", "h5", "he5"):
                 files_path += list(path.rglob(f"*.{ext}"))
-            
+
             files.append(sorted(files_path)[:max_files])
         else:
             files.append([path])
@@ -123,3 +123,71 @@ def get_all_data_from_paths(
         all_data.append(data_from_files)
     pbar.close()
     return all_data, expected_settings
+
+
+def social_vectors(poses):
+    """
+    Args:
+        poses (np.ndarray): n_tracks, n_fish, n_timesteps, (x,y, ori_rad)
+    """
+
+    assert not np.any(np.isnan(poses)), "poses contains NaNs"
+
+    n_tracks, n_fish, n_timesteps, three = poses.shape
+    assert three == 3, "poses.shape[-1] must be 3 (x,y,ori_rad). Got {three}"
+
+    sv = np.zeros((n_tracks, n_fish, n_timesteps, n_fish, 4))
+    for focal_id in range(n_fish):
+        for other_id in range(n_fish):
+            if focal_id == other_id:
+                continue
+
+            # Compute relative orienation
+            # sv[:, focal_id, :, other_id, 2] = (
+            # poses[:, other_id, :, 2] - poses[:, focal_id, :, 2]
+            # )
+            relative_orientation = poses[:, other_id, :, 2] - poses[:, focal_id, :, 2]
+            sv[:, focal_id, :, other_id, 2:4] = np.stack(
+                [np.cos(relative_orientation), np.sin(relative_orientation)], axis=-1
+            )
+
+            # Raw social vectors
+            xy_offset = poses[:, other_id, :, :2] - poses[:, focal_id, :, :2]
+
+            # Rotate them, so they are from the POV of the focal fish
+
+            # to convert a vector from world coords to fish coords (when the fish has
+            # an orientation of alpha), you need to rotate the vector by pi/2 first
+            # and then rotate it back by alpha
+            rotate_by = np.pi / 2 - poses[:, focal_id, :, 2]
+
+            # perform rotation of raw social vectors
+            sin, cos = np.sin(rotate_by), np.cos(rotate_by)
+            sv[:, focal_id, :, other_id, 0] = (
+                xy_offset[:, :, 0] * cos - xy_offset[:, :, 1] * sin
+            )
+            sv[:, focal_id, :, other_id, 1] = (
+                xy_offset[:, :, 0] * sin + xy_offset[:, :, 1] * cos
+            )
+
+            assert not np.isnan(sv).any(), np.isnan(sv)
+
+            # assert length of social vectors was preserved
+            assert True or np.all(
+                np.isclose(
+                    np.linalg.norm(
+                        social_vectors[:, focal_id, :, other_id, :2], axis=2
+                    ),
+                    np.linalg.norm(xy_offset, axis=2),
+                )
+            ), "The length of the social vector was not preserved."
+
+    return sv
+
+
+def social_vectors_without_focal_zeros(poses):
+    sv = social_vectors(poses)
+    mask = np.full_like(sv, fill_value=True, dtype=bool)
+    for focal_id in range(sv.shape[1]):
+        mask[:, focal_id, :, focal_id] = False
+    return sv[mask].reshape((sv.shape[0], sv.shape[1], sv.shape[2], sv.shape[3] - 1, 4))
diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py
index 4b1332f..1d0ff26 100644
--- a/tests/robofish/evaluate/test_app_evaluate.py
+++ b/tests/robofish/evaluate/test_app_evaluate.py
@@ -22,6 +22,7 @@ class DummyArgs:
         self.save_path = save_path
         self.labels = None
         self.add_train_data = add_train_data
+        self.max_files = None
 
 
 def test_app_validate(tmp_path):
diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py
index 6c3a16f..c9c7515 100644
--- a/tests/robofish/io/test_app_io.py
+++ b/tests/robofish/io/test_app_io.py
@@ -22,8 +22,8 @@ def test_app_validate():
     with pytest.warns(UserWarning):
         raw_output = app.validate(DummyArgs([resources_path], "raw"))
 
-    # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found.
-    assert len(raw_output) == 4
+    # The three files valid.hdf5, almost_valid.hdf5, valid_couzin_params.hdf5, and invalid.hdf5 should be found.
+    assert len(raw_output) == 5
 
     # invalid.hdf5 should pass a warning
     with pytest.warns(UserWarning):
-- 
GitLab