From f8561bfbf1bdcc93c7e14da0c34344effdf93c38 Mon Sep 17 00:00:00 2001 From: Mathis Hocke <mathis.hocke@fu-berlin.de> Date: Tue, 6 Jun 2023 19:11:37 +0200 Subject: [PATCH] Display zones for Couzin and Klamser models --- src/robofish/io/app.py | 7 +-- src/robofish/io/file.py | 102 ++++++++++++++++++++++++++++++---------- 2 files changed, 81 insertions(+), 28 deletions(-) diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 08bebc9..cf33493 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -198,7 +198,8 @@ def render(args: argparse.Namespace = None) -> None: """ parser = argparse.ArgumentParser( - description="This function shows the file as animation." + description="This function shows the file as animation.", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument( "path", @@ -248,14 +249,14 @@ def render(args: argparse.Namespace = None) -> None: f"--{key}", default=value, action="store_true" if value is False else "store_false", - help=f"Optional setter for video option {key}.\tDefault: {value}", + help=f"Optional setter for video option {key}.", ) else: parser.add_argument( f"--{key}", default=value, type=type(value), - help=f"Optional video option {key} with type {type(value)}.\tDefault: {value}", + help=f"Optional video option {key} with type {type(value)}.", ) if args is None: diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 97026b6..66d3b5a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -20,6 +20,7 @@ from __future__ import annotations import datetime import logging import platform +import re import shutil import tempfile import types @@ -1079,20 +1080,43 @@ class File(h5py.File): ax.set_facecolor("gray") plt.tight_layout(pad=0.05) n_entities = len(self.entities) + categories = [entity.attrs.get("category", None) for entity in self.entities] + n_fish = len([c for c in categories if c == "organism"]) + lines = [ plt.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] - zone_sizes = ( - get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) - if options["render_zones"] - else {} - ) + entity_colors = [lines[entity].get_color() for entity in range(n_entities)] + fish_colors = [ + color + for i, color in enumerate(entity_colors) + if categories[i] == "organism" + ] + + zone_sizes = [ + ( + get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) + if options["render_zones"] + else {} + ) + for _ in range(n_fish) + ] zones = [ - plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False) - for zone_size in zone_sizes.values() + [ + plt.Circle( + (0, 0), zone_size, color=fish_colors[i], alpha=0.2, fill=False + ) + for zone_size in zone_sizes_fish.values() + ] + for i, zone_sizes_fish in enumerate(zone_sizes) ] - if options["render_swarm_center"] and len([e for e in self.entities if e.category == "organism"]) > 0: + zones_flat = [] + for zones_fish in zones: + for zone in zones_fish: + ax.add_artist(zone) + zones_flat.append(zone) + if options["render_swarm_center"] and n_fish > 0: swarm_center_position = np.stack( [e.positions for e in self.entities if e.category == "organism"] ).mean(axis=0) @@ -1101,8 +1125,6 @@ class File(h5py.File): ] else: swarm_center = [] - for zone in zones: - ax.add_artist(zone) annotations = [] if options["show_ids"]: ids = None @@ -1133,8 +1155,6 @@ class File(h5py.File): plt.scatter([], [], marker="x", color="k"), plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], ] - categories = [entity.attrs.get("category", None) for entity in self.entities] - entity_colors = [lines[entity].get_color() for entity in range(n_entities)] entity_polygons = [ patches.Polygon( @@ -1147,12 +1167,14 @@ class File(h5py.File): alpha=0.8, ) for edgecolor, color, entity in [ - ("k", "white", self.entities[entity]) - if category == "robot" - else ( - entity_colors[entity], - entity_colors[entity], - self.entities[entity], + ( + ("k", "white", self.entities[entity]) + if category == "robot" + else ( + entity_colors[entity], + entity_colors[entity], + self.entities[entity], + ) ) for entity, category in enumerate(categories) ] @@ -1255,7 +1277,9 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border] + points + zones + swarm_center + return ( + lines + entity_polygons + [border] + points + zones_flat + swarm_center + ) n_frames = self.entity_poses.shape[1] @@ -1370,11 +1394,13 @@ class File(h5py.File): poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame ] - for zone in zones: - zone.center = ( - this_pose[0, 0], - this_pose[0, 1], - ) + for i_entity in range(n_entities): + if categories[i_entity] == "organism": + for zone in zones[i_entity]: + zone.center = ( + this_pose[i_entity, 0], + this_pose[i_entity, 1], + ) if options["render_swarm_center"]: swarm_center[0].set_offsets(swarm_center_position[file_frame]) @@ -1403,7 +1429,7 @@ class File(h5py.File): + [border] + points + annotations - + zones + + zones_flat + swarm_center ) @@ -1437,3 +1463,29 @@ class File(h5py.File): plt.close() else: plt.show() + + +def get_zone_sizes(model: str): + zone_sizes = {} + for zone in [ + "zor", + "zoa", + "zoo", + "preferred_d", + "neighbor_radius", + ]: + match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) + if match: + value = float(match.group(1)) + zone_sizes[zone] = value + else: + match = re.search(r"'{}': (\d+(?:\.\d+)?)".format(zone), model) + if match: + value = float(match.group(1)) + zone_sizes[zone] = value + + match = re.search(r"additive_zone_sizes[:,\'\" ]*(\b(?:True|False)\b)", model) + if match and match.group(1) == "True": + zone_sizes["zoo"] += zone_sizes["zor"] + zone_sizes["zoa"] += zone_sizes["zoo"] + return zone_sizes -- GitLab