From 8feae667bc6c42c4d51e149cbce7e94f00eaac6b Mon Sep 17 00:00:00 2001
From: Andi <andi.gerken@gmail.com>
Date: Tue, 25 Jul 2023 12:42:22 +0200
Subject: [PATCH] Added world shape

---
 src/conversion_scripts/convert_from_csv.py | 30 ++++++++++++-------
 src/robofish/evaluate/app.py               | 35 ++++++++++++++++++++--
 src/robofish/evaluate/evaluate.py          | 35 ++++++++++++----------
 src/robofish/io/file.py                    | 28 +++++++++++++----
 src/robofish/io/validation.py              |  7 +++++
 5 files changed, 101 insertions(+), 34 deletions(-)

diff --git a/src/conversion_scripts/convert_from_csv.py b/src/conversion_scripts/convert_from_csv.py
index 5ef4107..aadec8b 100644
--- a/src/conversion_scripts/convert_from_csv.py
+++ b/src/conversion_scripts/convert_from_csv.py
@@ -141,6 +141,10 @@ def handle_switches(poses, supress=None):
 
 def handle_file(file, args):
     pf = pd.read_csv(file, skiprows=4, header=None, sep=args.sep)
+    assert (
+        len(pf.columns) >= 2
+    ), f"{pf}\nThere were only {len(pf.columns)} columns in {file}. Looks like the seperator was maybe wrong? Choose it with --sep. Default is ;"
+
     if args.columns_per_entity is None:
         all_col_types_matching = []
         for cols in range(1, len(pf.columns) // 2 + 1):
@@ -163,18 +167,15 @@ def handle_file(file, args):
                 all_col_types_matching.append(cols)
             # print(cols, "\t", matching_types)
 
-        if len(all_col_types_matching) == 0:
-            print(
-                "Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity"
-            )
-        else:
-            assert [
-                all_col_types_matching[i] % all_col_types_matching[0] == 0
-                for i in range(0, len(all_col_types_matching))
-            ], f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}"
-            columns_per_entity = all_col_types_matching[0]
+        assert (
+            len(all_col_types_matching) > 0
+        ), f"Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity\ndatatypes were {dt}"
+        assert all(  # all() needed: a non-empty list is truthy even if every entry is False
+            all_col_types_matching[i] % all_col_types_matching[0] == 0
+            for i in range(0, len(all_col_types_matching))
+        ), f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}"
+        columns_per_entity = all_col_types_matching[0]
         print("Found columns_per_entity: %d" % columns_per_entity)
-
     else:
         columns_per_entity = int(args.columns_per_entity)
 
@@ -204,6 +205,7 @@ def handle_file(file, args):
         io_file_path,
         "w",
         world_size_cm=[args.world_size, args.world_size],
+        world_shape=args.world_shape,
         frequency_hz=args.frequency,
     ) as iof:
         poses = np.empty((pf_np.shape[0], n_fish, 3), dtype=np.float32)
@@ -380,6 +382,7 @@ parser = argparse.ArgumentParser(
 )
 parser.add_argument("path", nargs="+")
 parser.add_argument("-o", "--output", default=None)
+parser.add_argument("--world_shape", type=str)
 parser.add_argument("--sep", default=";")
 parser.add_argument("--header", default=3)
 parser.add_argument("--columns_per_entity", default=None)
@@ -399,6 +402,11 @@ parser.add_argument(
 )
 args = parser.parse_args()
 
+assert args.world_shape in [
+    "rectangle",
+    "ellipse",
+], f"--world_shape has to be 'rectangle' or 'ellipse' but is '{args.world_shape}'"
+
 if args.analysis_path is not None:
     args.analysis_path = Path(args.analysis_path)
 
diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index f0c7526..c7d0a44 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -14,6 +14,7 @@ import robofish.evaluate
 import argparse
 from pathlib import Path
 import matplotlib.pyplot as plt
+from copy import copy
 
 
 def function_dict() -> dict:
@@ -98,6 +99,12 @@ def evaluate(args: dict = None) -> None:
         help="Filename for saving resulting graphics.",
         default=None,
     )
+    parser.add_argument(
+        "--add_train_data",
+        action="store_true",
+        help="Add the training data to the evaluation.",
+        default=False,
+    )
 
     # TODO: ignore fish/ consider_names
 
@@ -108,11 +115,33 @@ def evaluate(args: dict = None) -> None:
         raise ValueError("When the analysis type is all, a --save_path must be given.")
 
     if args.analysis_type in fdict:
-        if args.labels is None:
-            args.labels = args.paths
+        paths = args.paths
+        labels = args.labels
+        print("starting", paths, labels)
+
+        if labels is None:
+            labels = copy(paths)
+
+        if args.add_train_data:
+            # Open any file to get the path
+            if len(paths) > 1:
+                warnings.warn(
+                    "Multiple paths given. Only the first path is used for the training data."
+                )
+
+            files = list(Path(args.paths[0]).rglob("*.hdf5"))
+            if len(files) == 0:  # stop here: files[0] below would otherwise fail with a bare IndexError
+                raise FileNotFoundError("No hdf5 files found in the given path.")
+            with robofish.io.File(files[0]) as f:
+                train_data = str(f.attrs["training_data"])
+
+                paths += [train_data]
+                labels += ["/".join(Path(train_data).parts[-2:])]
+
+        print("starting", paths, labels)
 
         save_path = None if args.save_path is None else Path(args.save_path)
-        params = {"paths": args.paths, "labels": args.labels}
+        params = {"paths": paths, "labels": labels}
         if args.analysis_type == "all":
             normal_functions = function_dict()
             normal_functions.pop("all")
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 1d5b3f4..bede8f7 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -185,7 +185,6 @@ def evaluate_orientation(
     orientations = []
     # Iterate all paths
     for poses_per_path in poses_from_paths:
-
         # Iterate all files
         for poses in poses_per_path:
             reshaped_poses = poses.reshape((-1, 4))
@@ -340,7 +339,6 @@ def evaluate_distance_to_wall(
 
         # Iterate all files
         for poses in poses_per_path:
-
             for e_poses in poses:
                 dist = []
                 for wall in wall_lines:
@@ -414,7 +412,6 @@ def evaluate_tank_position(
     # Iterate all paths
 
     for poses_per_path in poses_from_paths:
-
         # Get all positions for each file (skipping steps with poses_step), flatten them to (..., 2) and concatenate all positions.
         new_xy_positions = np.concatenate(
             [p[:, :, :2].reshape(-1, 2) for p in poses_per_path], axis=0
@@ -427,7 +424,6 @@ def evaluate_tank_position(
         ax = [ax]
 
     for i in range(len(xy_positions)):
-
         ax[i].set_xlim(
             -file_settings["world_size_cm_x"] / 2, file_settings["world_size_cm_x"] / 2
         )
@@ -682,14 +678,21 @@ def evaluate_social_vector(
             axis=-1,
         )
 
-        print(np.stack(poses).shape)
         social_vec = SocialVectors(poses).social_vectors_without_focal_zeros
         flat_sv = social_vec.reshape((-1, 3))
 
+        bins = (
+            30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100
+        )
+
         ax[i].hist2d(
-            flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=100
+            flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=bins
         )
         ax[i].set_title(labels[i])
+        if bins < 100:
+            ax[i].set_title(
+                f"{labels[i]} downsampled to {bins} bins.\nFor more details generate {60000/flat_sv.shape[0]:.1f} times as much."
+            )
     plt.suptitle("Social Vectors")
     return fig
 
@@ -734,9 +737,13 @@ def evaluate_follow_iid(
                             calculate_iid(poses[i, :-1, :2], poses[j, :-1, :2])
                         )
 
-                        path_follow.append(
-                            calculate_follow(poses[i, :, :2], poses[j, :, :2])
-                        )
+                        file_follow = calculate_follow(poses[i, :, :2], poses[j, :, :2])
+
+                        # Replace NaN and inf values with 0 before collecting the follow values
+                        file_follow = np.where(np.isnan(file_follow), 0, file_follow)
+                        file_follow = np.where(np.isinf(file_follow), 0, file_follow)
+
+                        path_follow.append(file_follow)
 
         follow.append(np.concatenate([np.atleast_1d(f) for f in path_follow]))
         iid.append(np.concatenate([np.atleast_1d(i) for i in path_iid]))
@@ -753,7 +760,6 @@ def evaluate_follow_iid(
     max_iid = np.quantile(np.concatenate(iid), 0.9)
 
     for i in range(len(follow)):
-
         # Mask data which is outside of the ranges
         mask = (
             (follow[i] > -1 * follow_range)
@@ -914,7 +920,6 @@ def evaluate_tracks(
                     if verbose:
                         print(f"Using file {file_path}")
                     with robofish.io.File(new_file_path, "r") as new_file:
-
                         new_file.plot(
                             selected_ax,
                             lw_distances=lw_distances,
@@ -925,13 +930,11 @@ def evaluate_tracks(
                             f"{selected_ax.get_title()} (reference track: id {initial_poses_info['track_id']})",
                         )
             else:
-
                 with robofish.io.File(file_path, "r") as file:
                     if (
                         "initial_poses_info" in file
                         and "track_id" in file["initial_poses_info"].attrs
                     ):
-
                         # Save the corresponding file that a model mimics.
                         initial_poses_info = {
                             a: file["initial_poses_info"].attrs[a]
@@ -1139,6 +1142,7 @@ def evaluate_all(
     predicate: Callable[[robofish.io.Entity], bool] = None,
     max_files: int = None,
     evaluations: List[str] = None,
+    file_format: str = "png",
 ) -> Iterable[Path]:
     """Generate all evaluation graphs and save them to a folder.
 
@@ -1184,7 +1188,7 @@ def evaluate_all(
         if evaluations is None or f_name in evaluations:
             t.set_description(f_name)
             t.refresh()  # to show the update immediately
-            save_path = save_folder / (f_name + ".png")
+            save_path = save_folder / (f_name + "." + file_format)
 
             requested_inputs = {
                 k: input_dict[k]
@@ -1195,7 +1199,8 @@ def evaluate_all(
                 paths=paths, labels=labels, predicate=predicate, **requested_inputs
             )
             if fig is not None:
-                fig.savefig(save_path)
+                # Increase resolution
+                fig.savefig(save_path, dpi=300, bbox_inches="tight")
                 plt.close(fig)
                 save_paths.append(save_path)
 
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index acdec48..8df570f 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -67,7 +67,7 @@ class File(h5py.File):
         mode: str = "r",
         *,  # PEP 3102
         world_size_cm: List[int] = None,
-        world_shape: str = "rectangle",
+        world_shape: str = None,
         validate: bool = False,
         validate_when_saving: bool = True,
         strict_validate: bool = False,
@@ -211,6 +211,15 @@ class File(h5py.File):
             ), "world_size_cm and format_version have to be given when creating a new file."
 
             self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32)
+
+            assert (
+                world_shape is not None
+            ), "world_shape has to be given when creating a new file."
+            assert world_shape in [
+                "rectangle",
+                "ellipse",
+            ], f"Unknown world shape {world_shape}."
+
             self.attrs["world_shape"] = world_shape
             self.attrs["format_version"] = np.array(format_version, dtype=np.int32)
             self.attrs["format_url"] = format_url
@@ -405,6 +414,15 @@ class File(h5py.File):
     def world_size(self):
         return self.attrs["world_size_cm"]
 
+    @property
+    def world_shape(self):
+        if "world_shape" not in self.attrs:
+            warnings.warn(
+                "File did not have a world_shape attribute. Assuming rectangle."
+            )
+            return "rectangle"
+        return self.attrs["world_shape"]
+
     @property
     def default_sampling(self):
         assert (
@@ -855,7 +873,7 @@ class File(h5py.File):
         lw: int = 2,
         ms: int = 32,
         figsize: Tuple[int] = None,
-        step_size: int = 4,
+        step_size: int = 25,
         c: List = None,
         cmap: matplotlib.colors.Colormap = "Set1",
         skip_timesteps=0,
@@ -1117,12 +1135,12 @@ class File(h5py.File):
             y = np.array([-1, -1, 1, 1, -1]) * self.world_size[1] / 2
             return np.array([x, y])
 
-        if "world_shape" not in self.attrs or self.attrs["world_shape"] == "rectangle":
+        if self.world_shape == "rectangle":
             border_vertices = create_square()
-        elif self.attrs["world_shape"] == "ellipse":
+        elif self.world_shape == "ellipse":
             border_vertices = create_circle(150)
         else:
-            raise ValueError(f"Unknown world shape: {self.attrs['world_shape']}")
+            raise ValueError(f"Unknown world shape: {self.world_shape}")
 
         spacing = 10
         x = np.arange(
diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py
index b211e10..1371a63 100644
--- a/src/robofish/io/validation.py
+++ b/src/robofish/io/validation.py
@@ -92,6 +92,13 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
             assert_validate(a in iofile.attrs, f'Attribute "{a}" missing","root')
             assert_validate_type(iofile.attrs[a], a_type, a, "root")
 
+        # Validate world shape
+        if "world_shape" in iofile.attrs:
+            assert_validate(
+                iofile.attrs["world_shape"] in ["rectangle", "ellipse"],
+                f"Invalid world shape {iofile.attrs['world_shape']}. Allowed are 'rectangle' and 'ellipse'.",
+            )
+
         # validate samplings
         assert_validate("samplings" in iofile, "samplings not found")
         for s_name, sampling in iofile["samplings"].items():
-- 
GitLab