diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index a75927cd7ce09240a7c81bb9bd5c4566a506b9de..dcb43405405abe5779a7d879de47676333c9321b 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -34,6 +34,7 @@ import types import warnings from textwrap import wrap import platform +import re import matplotlib import matplotlib.pyplot as plt @@ -1057,6 +1058,13 @@ class File(h5py.File): plt.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] + zone_sizes = get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) + zones = [ + plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False) + for zone_size in zone_sizes.values() + ] + for zone in zones: + ax.add_artist(zone) annotations = [] if options["show_ids"]: ids = None @@ -1182,7 +1190,7 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border] + points + return lines + entity_polygons + [border] + points + zones n_frames = self.entity_poses.shape[1] @@ -1299,6 +1307,11 @@ class File(h5py.File): poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame ] + for zone in zones: + zone.center = ( + this_pose[0, 0], + this_pose[0, 1], + ) for i_entity in range(n_entities): lines[i_entity].set_data( poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1] @@ -1317,7 +1330,13 @@ class File(h5py.File): ) return ( - output_list + lines + entity_polygons + [border] + points + annotations + output_list + + lines + + entity_polygons + + [border] + + points + + annotations + + zones ) print(f"Preparing to render n_frames: {n_frames}") @@ -1350,3 +1369,29 @@ class File(h5py.File): plt.close() else: plt.show() + + +def get_zone_sizes(model: str): + zone_sizes = {} + for zone in [ + "zor", + "zoa", + "zoo", + "preferred_d", + "neighbor_radius", + ]: + match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) + if match: + value = float(match.group(1)) + zone_sizes[zone] = value + else: + match = re.search(r"'{}': (\d+(?:\.\d+)?)".format(zone), model) + if match: + value = float(match.group(1)) + zone_sizes[zone] = value + + match = re.search(r"additive_zone_sizes[:,\'\" ]*(\b(?:True|False)\b)", model) + if match and match.group(1) == "True": + zone_sizes["zoo"] += zone_sizes["zor"] + zone_sizes["zoa"] += zone_sizes["zoo"] + return zone_sizes