diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index e734258dbe18b4b387417884533127b19b76f317..94ce119bdbf39869708435d4fb723a01df271f20 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -1084,37 +1084,27 @@ class File(h5py.File): [], [], lw=linewidth, - color=custom_colors[i % len(custom_colors) - 1] - if custom_colors - else None, + color=( + custom_colors[i % len(custom_colors) - 1] if custom_colors else None + ), zorder=0, )[0] for i in range(n_entities) ] entity_colors = [lines[entity].get_color() for entity in range(n_entities)] - fish_colors = [ - color - for i, color in enumerate(entity_colors) - if categories[i] == "organism" - ] zones = [] if render_zones: fovs = [] - for ei, e in enumerate(self.entities): - zone_sizes_str = get_zone_sizes_from_model_str( - self.attrs.get("guppy_model_rollout", "") + for entity, color in { + entity: color + for entity, color, category in zip( + self.entities, entity_colors, categories ) - zone_sizes_attrs = get_zone_sizes_from_attrs(e) - - # Check that there are no zone sizes in the model and in the attributes - assert ( - zone_sizes_str == {} or zone_sizes_attrs == {} - ), "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." - zone_sizes = ( - zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str - ) - + if category == "organism" + }.items(): + zone_sizes = get_zone_sizes_from_attrs(entity) + print(f"{zone_sizes = }") fov = zone_sizes.get("fov", np.pi * 2) fov = np.rad2deg(fov) zone_sizes.pop("fov", None) @@ -1127,7 +1117,7 @@ class File(h5py.File): matplotlib.patches.Circle( (0, 0), zone_size, - color=fish_colors[ei], + color=color, alpha=0.2, fill=False, ) @@ -1139,7 +1129,7 @@ class File(h5py.File): zone_size, theta1=-fov / 2, theta2=fov / 2, - color=fish_colors[ei], + color=color, alpha=0.2, fill=False, ) @@ -1514,24 +1504,11 @@ def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]: else: possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"] model_params = e["model"]["parameters"].attrs - return {p: model_params[p] for p in possible_params if p in model_params} - - -def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]: - zone_sizes = {} - for zone in ["zor", "zoa", "zoo", "preferred_d", "neighbor_radius", "fov"]: - match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) - if match: - value = float(match.group(1)) - zone_sizes[zone] = value - else: - match = re.search(r"'{}': (\d+(?:\.\d+)?)".format(zone), model) - if match: - value = float(match.group(1)) - zone_sizes[zone] = value - - match = re.search(r"additive_zone_sizes[:,\'\" ]*(\b(?:True|False)\b)", model) - if match and match.group(1) == "True": - zone_sizes["zoo"] += zone_sizes["zor"] - zone_sizes["zoa"] += zone_sizes["zoo"] - return zone_sizes + zone_sizes = {p: model_params[p] for p in possible_params if p in model_params} + if ( + "additive_zone_sizes" in model_params + and model_params["additive_zone_sizes"] + ): + zone_sizes["zoo"] = zone_sizes["zor"] + zone_sizes["zoo"] + zone_sizes["zoa"] = zone_sizes["zoo"] + zone_sizes["zoa"] + return zone_sizes