From 33228dc560d84d3acf0dceba27d22f4c85009036 Mon Sep 17 00:00:00 2001
From: Mathis Hocke <mathis.hocke@fu-berlin.de>
Date: Wed, 26 Jun 2024 15:03:04 +0200
Subject: [PATCH] Render swarm center

---
 src/robofish/io/app.py  |  2 ++
 src/robofish/io/file.py | 42 +++++++++++++++++++++++++++++++++++++++--
 2 files changed, 42 insertions(+), 2 deletions(-)

diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index 21f8365..08bebc9 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -236,6 +236,8 @@ def render(args: argparse.Namespace = None) -> None:
         "show_ids": False,
         "render_goals": False,
         "render_targets": False,
+        "render_zones": False,
+        "render_swarm_center": False,
         "highlight_switches": False,
         "figsize": 10,
     }
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index f289c87..97026b6 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -1060,6 +1060,8 @@ class File(h5py.File):
             "show_ids": False,
             "render_goals": False,
             "render_targets": False,
+            "render_zones": False,
+            "render_swarm_center": False,
             "highlight_switches": False,
             "dpi": 200,
             "figsize": 10,
@@ -1081,6 +1083,26 @@ class File(h5py.File):
             plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
             for _ in range(n_entities)
         ]
+        zone_sizes = (
+            get_zone_sizes(self.attrs.get("guppy_model_rollout", ""))
+            if options["render_zones"]
+            else {}
+        )
+        zones = [
+            plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False)
+            for zone_size in zone_sizes.values()
+        ]
+        if options["render_swarm_center"] and len([e for e in self.entities if e.category == "organism"]) > 0:
+            swarm_center_position = np.stack(
+                [e.positions for e in self.entities if e.category == "organism"]
+            ).mean(axis=0)
+            swarm_center = [
+                plt.scatter([], [], marker=".", color="k"),
+            ]
+        else:
+            swarm_center = []
+        for zone in zones:
+            ax.add_artist(zone)
         annotations = []
         if options["show_ids"]:
             ids = None
@@ -1233,7 +1255,7 @@ class File(h5py.File):
             for e_poly in entity_polygons:
                 ax.add_patch(e_poly)
             ax.add_patch(border)
-            return lines + entity_polygons + [border] + points
+            return lines + entity_polygons + [border] + points + zones + swarm_center
 
         n_frames = self.entity_poses.shape[1]
 
@@ -1348,6 +1370,15 @@ class File(h5py.File):
                 poses_trails = entity_poses[
                     :, max(0, file_frame - options["trail"]) : file_frame
                 ]
+                for zone in zones:
+                    zone.center = (
+                        this_pose[0, 0],
+                        this_pose[0, 1],
+                    )
+
+                if options["render_swarm_center"]:
+                    swarm_center[0].set_offsets(swarm_center_position[file_frame])
+
                 for i_entity in range(n_entities):
                     lines[i_entity].set_data(
                         poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
@@ -1366,7 +1397,14 @@ class File(h5py.File):
                 )
 
             return (
-                output_list + lines + entity_polygons + [border] + points + annotations
+                output_list
+                + lines
+                + entity_polygons
+                + [border]
+                + points
+                + annotations
+                + zones
+                + swarm_center
             )
 
         print(f"Preparing to render n_frames: {n_frames}")
-- 
GitLab