diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index c2e7a67d128039d075d4ece787ac59f2a900e3f8..7ff5b340a33d63d8d8b6430378702cd0020b0d53 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -236,7 +236,7 @@ class Entity(h5py.Group): "merge to the master branch of fish_models if nothing helps, contact Andi.\n" "Don't ignore this warning, it's a serious issue.", ) - def speed_turn(self): + def speed_turn(self) -> np.ndarray: """Get the speed, turn and from the positions. The vectors pointing from each position to the next are computed. diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 106396022174bcb7b80be1386080a7104c2d4118..ca734447e8860aa4bf28d3832dd63ee04ef20e89 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -29,7 +29,7 @@ import warnings from pathlib import Path from subprocess import run from textwrap import wrap -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import deprecation import h5py @@ -61,7 +61,7 @@ class File(h5py.File): def __init__( self, - path: Optinal[Union[str, Path]] = None, + path: Optional[Union[str, Path]] = None, mode: str = "r", *, # PEP 3102 world_size_cm: Optional[List[int]] = None, @@ -150,6 +150,8 @@ class File(h5py.File): self.validate_when_saving = validate_when_saving self.calculate_data_on_close = calculate_data_on_close + self._entities = None + if open_copy: assert ( path is not None @@ -502,6 +504,8 @@ class File(h5py.File): sampling=sampling, ) + self._entities = None # Reset the entities cache + return entity def create_multiple_entities( @@ -584,10 +588,12 @@ class File(h5py.File): @property def entities(self): - return [ - robofish.io.Entity.from_h5py_group(self["entities"][name]) - for name in self.entity_names - ] + if self._entities is None: + self._entities = [ + robofish.io.Entity.from_h5py_group(self["entities"][name]) + for name in self.entity_names + ] + return self._entities @property def entity_positions(self): @@ -702,7 +708,11 @@ class File(h5py.File): else: properties = [entity_property.__get__(entity) for entity in entities] - max_timesteps = max([0] + [p.shape[0] for p in properties]) + n_timesteps = [p.shape[0] for p in properties] + max_timesteps = max(n_timesteps) + + if np.all(np.equal(n_timesteps, max_timesteps)): + return np.array(properties) property_array = np.empty( (len(entities), max_timesteps, properties[0].shape[1]) @@ -1068,13 +1078,14 @@ class File(h5py.File): categories = [entity.attrs.get("category", None) for entity in self.entities] n_fish = len([c for c in categories if c == "organism"]) - lines = [ plt.plot( [], [], lw=linewidth, - color=custom_colors[i%len(custom_colors)-1] if custom_colors else None, + color=custom_colors[i % len(custom_colors) - 1] + if custom_colors + else None, zorder=0, )[0] for i in range(n_entities)