From 4a1bb3742bcf79bbdf31df65711f7e8aa6d732ab Mon Sep 17 00:00:00 2001
From: Mathis Hocke <mathis.hocke@fu-berlin.de>
Date: Tue, 6 Jun 2023 19:11:37 +0200
Subject: [PATCH] Display zones for Couzin and Klamser models

---
 src/robofish/io/file.py | 49 +++++++++++++++++++++++++++++++++++++++--
 1 file changed, 47 insertions(+), 2 deletions(-)

diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index a75927c..dcb4340 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -34,6 +34,7 @@ import types
 import warnings
 from textwrap import wrap
 import platform
+import re
 
 import matplotlib
 import matplotlib.pyplot as plt
@@ -1057,6 +1058,13 @@ class File(h5py.File):
             plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
             for _ in range(n_entities)
         ]
+        zone_sizes = get_zone_sizes(self.attrs.get("guppy_model_rollout", ""))
+        zones = [
+            plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False)
+            for zone_size in zone_sizes.values()
+        ]
+        for zone in zones:
+            ax.add_artist(zone)
         annotations = []
         if options["show_ids"]:
             ids = None
@@ -1182,7 +1190,7 @@ class File(h5py.File):
             for e_poly in entity_polygons:
                 ax.add_patch(e_poly)
             ax.add_patch(border)
-            return lines + entity_polygons + [border] + points
+            return lines + entity_polygons + [border] + points + zones
 
         n_frames = self.entity_poses.shape[1]
 
@@ -1299,6 +1307,11 @@ class File(h5py.File):
                 poses_trails = entity_poses[
                     :, max(0, file_frame - options["trail"]) : file_frame
                 ]
+                for zone in zones:
+                    zone.center = (
+                        this_pose[0, 0],
+                        this_pose[0, 1],
+                    )
                 for i_entity in range(n_entities):
                     lines[i_entity].set_data(
                         poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
@@ -1317,7 +1330,13 @@ class File(h5py.File):
                 )
 
             return (
-                output_list + lines + entity_polygons + [border] + points + annotations
+                output_list
+                + lines
+                + entity_polygons
+                + [border]
+                + points
+                + annotations
+                + zones
             )
 
         print(f"Preparing to render n_frames: {n_frames}")
@@ -1350,3 +1369,29 @@ class File(h5py.File):
             plt.close()
         else:
             plt.show()
+
+
+def get_zone_sizes(model: str):
+    zone_sizes = {}
+    for zone in [
+        "zor",
+        "zoa",
+        "zoo",
+        "preferred_d",
+        "neighbor_radius",
+    ]:
+        match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model)
+        if match:
+            value = float(match.group(1))
+            zone_sizes[zone] = value
+        else:
+            match = re.search(r"'{}': (\d+(?:\.\d+)?)".format(zone), model)
+            if match:
+                value = float(match.group(1))
+                zone_sizes[zone] = value
+
+    match = re.search(r"additive_zone_sizes[:,\'\" ]*(\b(?:True|False)\b)", model)
+    if match and match.group(1) == "True":
+        zone_sizes["zoo"] += zone_sizes["zor"]
+        zone_sizes["zoa"] += zone_sizes["zoo"]
+    return zone_sizes
-- 
GitLab