diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index 253ff2fceacac9ad10c6399edf19a2211ed55b9e..ada157912f8586719aa91e45432645fd65248322 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -148,7 +148,7 @@ " sf = robofish.io.File(path=example_file)\n", "\n", " print(\"\\nEntity Names\")\n", - " print(sf.get_entity_names())\n", + " print(sf.entity_names)\n", "\n", " # Get an array with all poses. As the length of poses varies per agent, it\n", " # is filled up with nans. The result is not interpolated and the time scales\n", @@ -156,10 +156,10 @@ " # different time scales and have another function, which generates an\n", " # interpolated array.\n", " print(\"\\nAll poses\")\n", - " print(sf.get_poses_array())\n", + " print(sf.select_entity_poses())\n", "\n", " print(\"\\nFish poses\")\n", - " print(sf.get_poses_array(fish_names))\n", + " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", "\n", " print(\"\\nFile structure\")\n", " print(sf)\n" @@ -187,4 +187,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/example_basic.py b/examples/example_basic.py index e328d380144d327512f7476e3493e687b57b43d8..b8d149ce4383025767d1d262335c4f20da3c2cc1 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -46,14 +46,14 @@ if __name__ == "__main__": sf = robofish.io.File(path=example_file) print("\nEntity Names") - print(sf.get_entity_names()) + print(sf.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") - print(sf.get_poses()) + print(sf.entity_poses) print("\nFish poses") - print(sf.get_poses(category="fish")) + print(sf.select_entity_poses(lambda e: e.category == "fish")) print("\nFile structure") print(sf) diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index e5329988e71b9f8e54018aee9e339b614ad73f64..89e5c22de9fac242391161760af0fd2d0779f8d7 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -54,8 +54,17 @@ class Entity(h5py.Group): ori_vec[:, 1] = np.sin(ori_rad[:, 0]) return ori_vec - def getName(self): - return self.name.split("/")[-1] + @property + def group_name(self): + return super().name + + @property + def name(self): + return self.group_name.split("/")[-1] + + @property + def category(self): + return self.attrs["category"] def create_outlines(self, outlines: Iterable, sampling=None): outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32) @@ -98,12 +107,21 @@ class Entity(h5py.Group): positions.attrs["sampling"] = sampling orientations.attrs["sampling"] = sampling - def get_poses(self): - poses = np.concatenate([self["positions"], self["orientations"]], axis=1) - return poses + @property + def positions(self): + return self["positions"] + + @property + def orientations(self): + return self["orientations"] + + @property + def poses(self): + return np.concatenate([self.positions, self.orientations], axis=1) - def get_poses_rad(self): - poses = self.get_poses() + @property + def poses_rad(self): + poses = self.poses # calculate the angles from the orientation vectors, write them to the third row and delete the fourth row poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2]) poses = poses[:, :3] diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 1914f99d520facecaeb122da8a83b2ac86cb48be..3bc5833140f9edbe3f87bf23227ac55a0e654610 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -201,7 +201,13 @@ class File(h5py.File): self["samplings"].attrs["default"] = self.default_sampling return name - def get_frequency(self): + @property + def world_size(self): + return self.attrs["world_size_cm"] + + @property + def frequency(self): + # NOTE: Only works if default sampling availabe and specified with frequency_hz. default_sampling = self["samplings"].attrs["default"] return self["samplings"][default_sampling].attrs["frequency_hz"] @@ -290,13 +296,8 @@ class File(h5py.File): ) return returned_names - def get_entities(self): - return { - e_name: robofish.io.Entity.from_h5py_group(e_group) - for e_name, e_group in self["entities"].items() - } - - def get_entity_names(self) -> Iterable[str]: + @property + def entity_names(self) -> Iterable[str]: """ Getter for the names of all entities Returns: @@ -304,8 +305,16 @@ class File(h5py.File): """ return sorted(self["entities"].keys()) - def get_poses(self, names: Iterable = None, category: str = None) -> Iterable: - """ Get an array of the poses of entities + @property + def entities(self): + return [robofish.io.Entity.from_h5py_group(self["entities"][name]) for name in self.entity_names] + + @property + def entity_poses(self): + return self.select_entity_poses(None) + + def select_entity_poses(self, predicate = None) -> Iterable: + """ Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -316,57 +325,29 @@ class File(h5py.File): An three dimensional array of all poses with the shape (entity, time, 4) """ - if names is not None and category is not None: - logging.error("Specify either names or a category, not both.") - raise Exception - - # collect the names of all entities with the correct category - if category is not None: - names = [ - e_name - for e_name, e_data in self["entities"].items() - if e_data.attrs["category"] == category - ] - - entities = self.get_entities() - - # If no names or category are given, select all - if names is None: - names = sorted(entities.keys()) - - # Entity objects given as names - if all([type(name) == robofish.io.Entity for name in names]): - names = [entity.getName() for entity in names] + entities = self.entities + if predicate is not None: + entities = [e for e in entities if predicate(e)] - if not all([type(name) == str for name in names]): - raise Exception( - "Given names were not strings. Instead names were %s" % names - ) - - max_timesteps = ( - 0 - if len(names) == 0 - else max([entities[e_name]["positions"].shape[0] for e_name in names]) - ) + max_timesteps = max([0] + [e.positions.shape[0] for e in entities]) # Initialize poses output array - poses_output = np.empty((len(names), max_timesteps, 4)) + poses_output = np.empty((len(entities), max_timesteps, 4)) poses_output[:] = np.nan # Fill poses output array i = 0 custom_sampling = None - for name in names: - entity = entities[name] + for entity in entities: if "sampling" in entity["positions"].attrs: if custom_sampling is None: custom_sampling = entity["positions"].attrs["sampling"] elif custom_sampling != entity["positions"].attrs["sampling"]: raise Exception( - "Multiple samplings found, which can not be given back by the get_poses function collectively." + "Multiple samplings found, preventing return of a single array." ) - poses = entity.get_poses() + poses = entity.poses poses_output[i][: poses.shape[0]] = poses i += 1 return poses_output diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 776984b606fb89469a0d02cc860ac4489439ea63..0718d2d95b5685baa46147f3f390179c0b8a16f5 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -150,7 +150,9 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): # validate entities assert_validate("entities" in iofile, "entities not found") - for e_name, entity in iofile.get_entities().items(): + for entity in iofile.entities: + e_name = entity.name + assert_validate( type(entity) == Entity, "Entity group was not a robofish.io.Entity object", diff --git a/tests/robofish/io/test_entity.py b/tests/robofish/io/test_entity.py index 0ebf09738a53d8f983b8a99f8c54de7c351b8bda..8b3f22d809398a433f88dfb68c881e9d5dab2f82 100644 --- a/tests/robofish/io/test_entity.py +++ b/tests/robofish/io/test_entity.py @@ -12,7 +12,7 @@ def test_entity_object(): sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) f = sf.create_entity("fish", positions=[[10, 10]]) assert type(f) == robofish.io.Entity - assert f.getName() == "fish_1" + assert f.name == "fish_1" assert f.attrs["category"] == "fish" print(dir(f)) print(f["positions"]) @@ -26,7 +26,7 @@ def test_entity_object(): f2 = sf.create_entity("fish", poses=poses_rad) assert type(f2["positions"]) == h5py.Dataset assert type(f2["orientations"]) == h5py.Dataset - poses_rad_retrieved = f2.get_poses_rad() + poses_rad_retrieved = f2.poses_rad # Check if retrieved rad poses is close to the original poses. # Internally always ori_x and ori_y are used. When retrieved, the range is from -pi to pi, so for some of our original data 2 pi has to be substracted. diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 8ce4bd4b9899f0fd785dcf315e7f75ee79eb26e2..3caab455abc921f601a25f6aff5301167dc081a5 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -69,7 +69,7 @@ def test_multiple_entities(): sf = robofish.io.File(world_size_cm=[100, 100], monotonic_time_points_us=m_points) returned_entities = sf.create_multiple_entities("fish", poses) - returned_names = [entity.getName() for entity in returned_entities] + returned_names = [entity.name for entity in returned_entities] expected_names = ["fish_1", "fish_2", "fish_3"] print(returned_names) @@ -79,24 +79,24 @@ def test_multiple_entities(): sf.validate() # The returned poses should be equal to the inserted poses - returned_poses = sf.get_poses() + returned_poses = sf.entity_poses print(returned_poses) assert (returned_poses == poses).all() # Just get the array for some names - returned_poses = sf.get_poses(["fish_1", "fish_2"]) + returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"]) assert (returned_poses == poses[:2]).all() - # Falsely specify names and category - with pytest.raises(Exception): - sf.get_poses(names=["fish_1"], category="fish") + # Filter on both category and name + returned_poses = sf.select_entity_poses(lambda e: e.category == "fish" and e.name == "fish_1") + assert (returned_poses == poses[:1]).all() # Insert some random obstacles returned_names = sf.create_multiple_entities( "obstacle", poses=np.random.random((agents, timesteps, 4)) ) # Obstacles should not be returned when only fish are selected - returned_poses = sf.get_poses(category="fish") + returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() # for each of the entities @@ -124,12 +124,12 @@ def test_multiple_entities(): print(returned_names) print(sf) - # pass an poses array in separate parts (positions, orientations) and retreive it with get_poses. + # pass an poses array in separate parts (positions, orientations) and retrieve it with poses. poses_arr = np.random.random((100, 4)) position_orientation_fish = sf.create_entity( "fish", positions=poses_arr[:, :2], orientations=poses_arr[:, 2:] ) - assert np.isclose(poses_arr, position_orientation_fish.get_poses()).all() + assert np.isclose(poses_arr, position_orientation_fish.poses).all() sf.validate() return sf @@ -142,7 +142,7 @@ def test_load_validate(): def test_get_entity_names(): sf = robofish.io.File(path=valid_file_path) - names = sf.get_entity_names() + names = sf.entity_names assert len(names) == 9 assert names[0] == "fish_1" assert names[1] == "fish_2"