From 5c588e8c8a456f38c118f94499f13525c23ce3ef Mon Sep 17 00:00:00 2001
From: Andi Gerken <andi.gerken@gmail.com>
Date: Thu, 15 Dec 2022 18:00:09 +0000
Subject: [PATCH] Improved switch detection

---
 src/conversion_scripts/convert_from_csv.py | 257 ++++++++++-----------
 src/conversion_scripts/create_test_csv.py  |  33 +++
 src/conversion_scripts/print_file.py       |  41 ++++
 src/robofish/io/app.py                     |   1 +
 src/robofish/io/file.py                    |  29 ++-
 5 files changed, 224 insertions(+), 137 deletions(-)
 create mode 100644 src/conversion_scripts/create_test_csv.py
 create mode 100644 src/conversion_scripts/print_file.py

diff --git a/src/conversion_scripts/convert_from_csv.py b/src/conversion_scripts/convert_from_csv.py
index e437cd8..def488e 100644
--- a/src/conversion_scripts/convert_from_csv.py
+++ b/src/conversion_scripts/convert_from_csv.py
@@ -15,6 +15,7 @@ The script is not finished and still in progress.
 import pandas as pd
 import numpy as np
 import argparse
+import itertools
 from pathlib import Path
 import robofish.io
 
@@ -27,41 +28,7 @@ except ImportError as e:
         return x
 
 
-DEFAULT_COLUMNS = [
-    "Framenumber",
-    "TimeHuman",
-    "Time",
-    "Type1",
-    "ID",
-    "Robo x",
-    "Robo y",
-    "Robo ori deg",
-    "Robo ori rad",
-    "Type2",
-    "ID2",
-    "Fish x",
-    "Fish y",
-    "Fish ori deg",
-    "Fish ori rad",
-    "Mode",
-    "FishModel",
-    "LeadSubexperiment",
-    "fear",
-    "follow",
-    "cz",
-    "rz",
-    "lz",
-    "v_max",
-    "adapt",
-    "ADir",
-    "ASpeed",
-    "AClose",
-    "LSpeed",
-    "LClose",
-]
-
-
-def get_distances(poses, last_poses=None):
+def get_distances(poses, last_poses=None, diagonal=False):
     def get_single_distance(pose_1, pose_2):
         assert pose_1.shape == pose_2.shape == (3,)
         return np.linalg.norm(pose_1[:2] - pose_2[:2])
@@ -71,10 +38,11 @@ def get_distances(poses, last_poses=None):
         distances = np.zeros((n_fish, n_fish))
         for i in range(n_fish):
             for j in range(n_fish):
-                if last_poses is None:
-                    distances[i, j] = get_single_distance(poses[i], poses[j])
-                else:
-                    distances[i, j] = get_single_distance(poses[i], last_poses[j])
+                if i == j or not diagonal:
+                    if last_poses is None:
+                        distances[i, j] = get_single_distance(poses[i], poses[j])
+                    else:
+                        distances[i, j] = get_single_distance(poses[i], last_poses[j])
     else:
         n_frames, n_fish, three = poses.shape
         distances = np.zeros((n_frames, n_fish, n_fish))
@@ -82,11 +50,96 @@ def get_distances(poses, last_poses=None):
         for t in range(n_frames - 1):
             for i in range(n_fish):
                 for j in range(n_fish):
-                    assert last_poses is None
-                    distances[t, i, j] = get_single_distance(
-                        poses[t + 1, i], poses[t, j]
-                    )
-    return distances
+                    if i == j or not diagonal:
+                        assert last_poses is None
+                        distances[t, i, j] = get_single_distance(
+                            poses[t + 1, i], poses[t, j]
+                        )
+    if diagonal:
+        return np.diagonal(distances, axis1=-1, axis2=-2)
+    else:
+        return distances
+
+
+def handle_switches(poses, supress=None):
+    """Find and handle switches in the data.
+    Switches are defined as a fish that is not at the same position as it was in the previous frame.
+    If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame.
+    If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved.
+
+    Args:
+        poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
+        supress (list): List of frames to ignore.
+    Returns:
+        np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
+    """
+
+    n_fish = poses.shape[1]
+
+    all_switches = []
+
+    last_poses = np.copy(poses[0])
+
+    all_connection_permutations = list(itertools.permutations(np.arange(n_fish)))
+
+    for t in range(1, poses.shape[0]):
+        if np.all(np.isclose(np.abs(np.diff(poses[t - 1], axis=0)), 0)):
+            print(
+                f"Warning: All fish at same position in frame {t - 1} setting them to NaN"
+            )
+            poses[t - 1] = np.nan
+
+        if supress is None or t not in supress:
+
+            distances = get_distances(poses[t], last_poses)
+
+            switches = {}
+            for i in range(n_fish):
+                if np.argmin(distances[i]) != i:
+                    switches[i] = np.argmin(distances[i])
+
+            if len(switches) > 0 or np.any(
+                np.isnan(poses[t]) != np.isnan(poses[t - 1])
+            ):
+
+                switch_distances = np.array(
+                    [
+                        get_distances(poses[t, con_perm], last_poses, diagonal=True)
+                        for con_perm in all_connection_permutations
+                    ]
+                )
+
+                switch_distances_sum = np.nansum(switch_distances, axis=1)
+
+                # if t > 10 and t < 18:
+                #    print("Poses\n", poses[t], "\nLast Poses\n", last_poses)
+                #    print(t, "\n", switch_distances[:2], "\n", switch_distances_sum[:2])
+
+                connections = np.array(
+                    all_connection_permutations[np.argmin(switch_distances_sum)]
+                ).astype(int)
+
+                distances_normal = switch_distances_sum[0]
+                distances_switched = np.min(switch_distances_sum)
+
+                if np.argmin(switch_distances_sum) != 0:
+                    if distances_switched * 1.5 < distances_normal:
+
+                        print(
+                            f"Switch: {connections} distance sum\t{np.min(switch_distances_sum):.2f} vs\t{np.min(switch_distances_sum[0]):.2f}"
+                        )
+                        poses[t:] = poses[t:, connections]
+
+                        all_switches.append(t)
+
+        # Update last poses for every fish that is not nan
+        last_poses[np.where(~np.isnan(poses[t]))] = poses[t][
+            np.where(~np.isnan(poses[t]))
+        ]
+
+        assert not np.any(np.isnan(last_poses)), "Error: NaN in last_poses"
+
+    return poses, all_switches
 
 
 def handle_file(file, args):
@@ -114,7 +167,6 @@ def handle_file(file, args):
                 all_col_types_matching.append(cols)
             # print(cols, "\t", matching_types)
 
-        # pf.columns = DEFAULT_COLUMNS
         if len(all_col_types_matching) == 0:
             print(
                 "Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity"
@@ -128,7 +180,7 @@ def handle_file(file, args):
         print("Found columns_per_entity: %d" % columns_per_entity)
 
     else:
-        columns_per_entity = args.columns_per_entity
+        columns_per_entity = int(args.columns_per_entity)
 
     n_fish = len(pf.columns) // columns_per_entity
     header_cols = len(pf.columns) % n_fish
@@ -167,113 +219,48 @@ def handle_file(file, args):
                 + f * columns_per_entity : header_cols
                 + (f + 1) * columns_per_entity,
             ]
-            poses[:, f] = f_cols[:, [args.xcol, args.ycol, args.oricol]].astype(
-                np.float32
-            )
-
-        poses[:, :, :2] -= args.world_size / 2  #  center world around 0,0
-        poses[:, :, 1] *= -1  #                    flip y axis
-        poses[:, :, 2] *= 2 * np.pi / 360  #       convert to radians
-        poses[:, :, 2] -= np.pi  #                 rotate 180 degrees
-
-        def handle_switches(poses, supress=None):
-            """Find and handle switches in the data.
-            Switches are defined as a fish that is not at the same position as it was in the previous frame.
-            If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame.
-            If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved.
-
-            Args:
-                poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
-                supress (list): List of frames to ignore.
-            Returns:
-                np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
-            """
-
-            all_switches = []
-            last_poses = np.copy(poses[0])
-            assert not np.any(np.isnan(last_poses)), "Error: NaN in first frame"
-            for t in range(1, poses.shape[0]):
-                if not np.any(np.isnan(poses[t])) and (
-                    supress is None or t not in supress
-                ):
-
-                    distances = get_distances(poses[t], last_poses)
-
-                    switches = {}
-                    for i in range(n_fish):
-                        if np.argmin(distances[i]) != i:
-                            switches[i] = np.argmin(distances[i])
-
-                    if len(switches) > 0:
-                        print(f"Switches at time {t}: {switches}")
-
-                        if sorted(switches.keys()) == sorted(switches.values()):
-                            print("Switches are one-to-one")
-
-                            for i in switches:
-                                print(list(switches.keys()), list(switches.values()))
-                            connections = np.arange(n_fish, dtype=int)
-                            for k, v in switches.items():
-                                connections[v] = k
-                        else:
-                            print("Attempting to fix switches...")
-
-                            smallest_distances = np.argsort(distances.flatten())
-
-                            connections = np.empty(n_fish)
-                            connections[:] = np.nan
-
-                            print(distances)
-                            print(
-                                smallest_distances[0] // n_fish,
-                                smallest_distances[0] % n_fish,
-                            )
-                            for sd in smallest_distances:
-                                if np.isnan(connections[sd // n_fish]) and not np.any(
-                                    connections == sd % n_fish
-                                ):
-                                    connections[sd // n_fish] = sd % n_fish
-
-                                if not np.any(np.isnan(connections)):
-                                    break
-                            assert np.sum(connections) == np.sum(
-                                range(n_fish)
-                            )  # Simple check to see if all fish are connected
-
-                        connections = connections.astype(int)
-                        print(f"Connections: {connections}")
-                        poses[t:] = poses[t:, connections]
-
-                        all_switches.append(t)
+            poses[:, f] = f_cols[
+                :, [int(args.xcol), int(args.ycol), int(args.oricol)]
+            ].astype(np.float32)
 
-                    last_poses = poses[t]
-            return poses, all_switches
+        if not args.disable_centering:
+            poses[:, :, :2] -= args.world_size / 2  #  center world around 0,0
+            poses[:, :, 1] *= -1  #                    flip y axis
+            poses[:, :, 2] *= 2 * np.pi / 360  #       convert to radians
+            poses[:, :, 2] -= np.pi  #                 rotate 180 degrees
 
         supress = []
 
+        all_switches = None
         if not args.disable_fix_switches:
             for run in range(20):
                 print("RUN ", run)
                 switched_poses, all_switches = handle_switches(poses)
-
-                print(all_switches)
+                print("All switches: ", all_switches)
                 diff = np.diff(all_switches, axis=0)
 
-                new_supress = np.array(all_switches)[np.where(diff < 10)]
+                new_supress = np.array(all_switches)[
+                    np.where(diff < args.min_timesteps_between_switches)
+                ]
                 new_supress = [n for n in new_supress if n not in supress]
+
                 if len(new_supress) == 0:
+                    print("No need to supress more.")
                     break
-                print(supress)
+
                 supress.extend(new_supress)
+                print("supressing in the next runn: ", supress)
 
-            distances = get_distances(switched_poses)
-            switched_poses[
-                np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
-            ] = np.nan
+            # distances = get_distances(switched_poses)
+            # switched_poses[
+            #    np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
+            # ] = np.nan
             poses = switched_poses
 
         for f in range(n_fish):
             iof.create_entity("fish", poses[:, f])
+        if all_switches is not None:
+            iof.attrs["switches"] = all_switches
 
         # assert np.all(
         #     poses[np.logical_not(np.isnan(poses[:, 0])), 0] >= 0
@@ -299,6 +286,8 @@ parser.add_argument("--oricol", default=7)
 parser.add_argument("--world_size", default=100)
 parser.add_argument("--frequency", default=25)
 parser.add_argument("--disable_fix_switches", action="store_true")
+parser.add_argument("--disable_centering", action="store_true")
+parser.add_argument("--min_timesteps_between_switches", type=int, default=0)
 args = parser.parse_args()
 
 for path in args.path:
diff --git a/src/conversion_scripts/create_test_csv.py b/src/conversion_scripts/create_test_csv.py
new file mode 100644
index 0000000..40623ec
--- /dev/null
+++ b/src/conversion_scripts/create_test_csv.py
@@ -0,0 +1,33 @@
+import pandas as pd
+import numpy as np
+
+# python convert_from_csv.py test.csv --sep , --header 1 --columns_per_entity 3 --xcol 0 --ycol 1 --oricol 2 --disable_fix_switches
+
+df = pd.DataFrame(columns=["header", "x1", "y1", "o1", "x2", "y2", "o2"])
+
+df["header"] = 0
+
+df["x1"] = np.linspace(20, 80, 60)
+df["x2"] = np.linspace(20, 80, 60)
+df["y1"] = 40
+df["y2"] = 60
+
+# Switch y1 and y2 after 50 samples
+df.loc[20:30, "y1"] = 60
+df.loc[20:30, "y2"] = 40
+
+# Set x1 and y1 to np.nan for 5 samples
+df.loc[40:45, "x1"] = np.nan
+df.loc[40:45, "y1"] = np.nan
+df.loc[46:50, "y1"] = 60
+df.loc[46:50, "y2"] = 40
+df.loc[51:55, "x2"] = np.nan
+df.loc[51:55, "y2"] = np.nan
+
+
+df["o1"] = 180
+df["o2"] = 180
+
+
+print(df)
+df.to_csv("test.csv", index=False)
diff --git a/src/conversion_scripts/print_file.py b/src/conversion_scripts/print_file.py
new file mode 100644
index 0000000..03455b1
--- /dev/null
+++ b/src/conversion_scripts/print_file.py
@@ -0,0 +1,41 @@
+import robofish.io
+import pandas as pd
+import numpy as np
+
+
+def print_file(args):
+
+    # Open the robofish.io.File
+    with robofish.io.File(args.filename, "r") as f:
+        timesteps = (
+            args.max_timesteps + args.skip_timesteps
+            if args.max_timesteps is not None
+            else f.entity_positions.shape[1]
+        )
+
+        n_fish = len(f.entities)
+        df = pd.DataFrame(
+            columns=np.concatenate([[f"x{i}", f"y{i}", f"o{i}"] for i in range(n_fish)])
+        )
+
+        for i in range(n_fish):
+            df[f"x{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 0]
+            df[f"y{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 1]
+            df[f"o{i}"] = f.entity_orientations_rad[i][
+                args.skip_timesteps : timesteps, 0
+            ]
+
+        print(df)
+
+    df.to_csv("cutout.csv", index=False)
+
+
+if __name__ == "__main__":
+    import argparse
+
+    parser = argparse.ArgumentParser(description="Print a file")
+    parser.add_argument("filename", help="File to print")
+    parser.add_argument("--skip_timesteps", type=int, default=0, help="Skip timesteps")
+    parser.add_argument("--max_timesteps", type=int, default=None, help="Max timesteps")
+    args = parser.parse_args()
+    print_file(args)
diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index 16f714d..eaf9d96 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -235,6 +235,7 @@ def render(args: argparse.Namespace = None) -> None:
         "show_ids": False,
         "render_goals": False,
         "render_targets": False,
+        "highlight_switches": False,
         "figsize": 10,
     }
 
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index b3f7913..114e417 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -1035,6 +1035,7 @@ class File(h5py.File):
             "show_ids": False,
             "render_goals": False,
             "render_targets": False,
+            "highlight_switches": False,
             "dpi": 200,
             "figsize": 10,
         }
@@ -1086,16 +1087,20 @@ class File(h5py.File):
             plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
         ]
         categories = [entity.attrs.get("category", None) for entity in self.entities]
+        entity_colors = [lines[entity].get_color() for entity in range(n_entities)]
 
         entity_polygons = [
             patches.Polygon(
                 shape_vertices(options["entity_scale"]),
                 edgecolor=edgecolor,
                 facecolor=color,
+                alpha=0.8,
             )
             for edgecolor, color in [
-                ("k", "white") if category == "robot" else ("k", "k")
-                for category in categories
+                ("k", "white")
+                if category == "robot"
+                else (entity_colors[entity], entity_colors[entity])
+                for entity, category in enumerate(categories)
             ]
         ]
 
@@ -1197,6 +1202,8 @@ class File(h5py.File):
             pbar = tqdm(range(n_frames))
 
         def update(frame):
+            output_list = []
+
             if "pbar" in locals().keys():
                 pbar.update(1)
                 pbar.refresh()
@@ -1207,6 +1214,19 @@ class File(h5py.File):
                 file_frame = (frame * options["speedup"]) + frame_range[0]
                 this_pose = entity_poses[:, file_frame]
 
+                if options["highlight_switches"] and "switches" in self.attrs:
+                    if any(
+                        [
+                            file_frame + i in self.attrs["switches"]
+                            for i in range(options["speedup"])
+                        ]
+                    ):
+                        ax.set_facecolor("lightgray")
+
+                    else:
+                        ax.set_facecolor("white")
+                    output_list.append(ax)
+
                 if not options["fixed_view"]:
 
                     # Find the maximal distance between the entities in x or y direction
@@ -1284,7 +1304,10 @@ class File(h5py.File):
                 raise Exception(
                     f"Frame is bigger than n_frames {file_frame} of {n_frames}"
                 )
-            return lines + entity_polygons + [border] + points + annotations
+
+            return (
+                output_list + lines + entity_polygons + [border] + points + annotations
+            )
 
         print(f"Preparing to render n_frames: {n_frames}")
 
-- 
GitLab