diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index ca734447e8860aa4bf28d3832dd63ee04ef20e89..c45bb741bf2414992c2a7f19b3c621ca7fec458a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -29,7 +29,7 @@ import warnings from pathlib import Path from subprocess import run from textwrap import wrap -from typing import Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import deprecation import h5py @@ -1097,23 +1097,25 @@ class File(h5py.File): if categories[i] == "organism" ] - zone_sizes = [ - ( - get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) - if render_zones - else {} - ) - for _ in range(n_fish) - ] - zones = [ - [ - plt.Circle( - (0, 0), zone_size, color=fish_colors[i], alpha=0.2, fill=False - ) - for zone_size in zone_sizes_fish.values() - ] - for i, zone_sizes_fish in enumerate(zone_sizes) - ] + zones = [] + if render_zones: + for ei, e in enumerate(self.entities): + zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", "")) + zone_sizes_attrs = get_zone_sizes_from_attrs(e) + + # Check that there are no zone sizes in the model and in the attributes + assert zone_sizes_str == {} or zone_sizes_attrs == {}, "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." + zone_sizes = zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str + + fov = zone_sizes.get("fov", np.pi*2) + fov = np.rad2deg(fov) + zone_sizes.pop("fov", None) + + entity_zones = [] + for zone_size in zone_sizes.values(): + entity_zones.append(matplotlib.patches.Arc((0,0), zone_size, zone_size, angle=0, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.3, fill=False)) + zones.append(entity_zones) + zones_flat = [] for zones_fish in zones: for zone in zones_fish: @@ -1388,12 +1390,12 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): - if categories[i_entity] == "organism": - for zone in zones[i_entity]: - zone.center = ( - this_pose[i_entity, 0], - this_pose[i_entity, 1], - ) + for zone in zones[i_entity]: + zone.center = ( + this_pose[i_entity, 0], + this_pose[i_entity, 1], + ) + zone.angle = this_pose[i_entity, 2] * 180 / np.pi if render_swarm_center: swarm_center[0].set_offsets(swarm_center_position[file_frame]) @@ -1454,7 +1456,16 @@ class File(h5py.File): plt.show() -def get_zone_sizes(model: str): +def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]: + if "model" not in e: + return {} + else: + possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"] + model_params = e["model"]["parameters"].attrs + return {p: model_params[p] for p in possible_params if p in model_params} + + +def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]: zone_sizes = {} for zone in [ "zor", @@ -1462,6 +1473,7 @@ def get_zone_sizes(model: str): "zoo", "preferred_d", "neighbor_radius", + "fov" ]: match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) if match: diff --git a/tests/resources/valid_couzin_params.hdf5 b/tests/resources/valid_couzin_params.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..e8c7ef32b4b929f0db806dadbfd266ab1eaca66b Binary files /dev/null and b/tests/resources/valid_couzin_params.hdf5 differ