From 33228dc560d84d3acf0dceba27d22f4c85009036 Mon Sep 17 00:00:00 2001 From: Mathis Hocke <mathis.hocke@fu-berlin.de> Date: Wed, 26 Jun 2024 15:03:04 +0200 Subject: [PATCH] Render swarm center --- src/robofish/io/app.py | 2 ++ src/robofish/io/file.py | 42 +++++++++++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 2 deletions(-) diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 21f8365..08bebc9 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -236,6 +236,8 @@ def render(args: argparse.Namespace = None) -> None: "show_ids": False, "render_goals": False, "render_targets": False, + "render_zones": False, + "render_swarm_center": False, "highlight_switches": False, "figsize": 10, } diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index f289c87..97026b6 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -1060,6 +1060,8 @@ class File(h5py.File): "show_ids": False, "render_goals": False, "render_targets": False, + "render_zones": False, + "render_swarm_center": False, "highlight_switches": False, "dpi": 200, "figsize": 10, @@ -1081,6 +1083,26 @@ class File(h5py.File): plt.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] + zone_sizes = ( + get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) + if options["render_zones"] + else {} + ) + zones = [ + plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False) + for zone_size in zone_sizes.values() + ] + if options["render_swarm_center"] and len([e for e in self.entities if e.category == "organism"]) > 0: + swarm_center_position = np.stack( + [e.positions for e in self.entities if e.category == "organism"] + ).mean(axis=0) + swarm_center = [ + plt.scatter([], [], marker=".", color="k"), + ] + else: + swarm_center = [] + for zone in zones: + ax.add_artist(zone) annotations = [] if options["show_ids"]: ids = None @@ -1233,7 +1255,7 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border] + points + return lines + entity_polygons + [border] + points + zones + swarm_center n_frames = self.entity_poses.shape[1] @@ -1348,6 +1370,15 @@ class File(h5py.File): poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame ] + for zone in zones: + zone.center = ( + this_pose[0, 0], + this_pose[0, 1], + ) + + if options["render_swarm_center"]: + swarm_center[0].set_offsets(swarm_center_position[file_frame]) + for i_entity in range(n_entities): lines[i_entity].set_data( poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1] @@ -1366,7 +1397,14 @@ class File(h5py.File): ) return ( - output_list + lines + entity_polygons + [border] + points + annotations + output_list + + lines + + entity_polygons + + [border] + + points + + annotations + + zones + + swarm_center ) print(f"Preparing to render n_frames: {n_frames}") -- GitLab