diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index ada157912f8586719aa91e45432645fd65248322..5445fe80256a9ee78bc0ae8e3c439135750b41b8 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -156,10 +156,10 @@ " # different time scales and have another function, which generates an\n", " # interpolated array.\n", " print(\"\\nAll poses\")\n", - " print(sf.select_entity_poses())\n", + " print(sf.select_poses())\n", "\n", " print(\"\\nFish poses\")\n", - " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", + " print(sf.select_poses(lambda e: e.category == \"fish\"))\n", "\n", " print(\"\\nFile structure\")\n", " print(sf)\n" diff --git a/examples/example_basic.py b/examples/example_basic.py index f0d46d13d268292c1390a9880688316cef344328..70c39b6b862543abba085657a339d4588b8cb67c 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -43,10 +43,10 @@ print(sf.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") -print(sf.entity_poses) +print(sf.poses) print("\nFish poses") -print(sf.select_entity_poses(lambda e: e.category == "fish")) +print(sf.select_poses(lambda e: e.category == "fish")) print("\nFile structure") print(sf) diff --git a/examples/example_readme.py b/examples/example_readme.py index d634df04bdb07cd94d036cf50ae7e54878be05c6..71442e29faac159ffc739a10340ea65ecc64e005 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -32,4 +32,4 @@ with robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0) as # Show and save the file print(f) - print("Poses Shape: ", f.entity_poses.shape) + print("Poses Shape: ", f.poses.shape) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index da0edc38493fecb7881f2ab01d367d4bfea5db9a..2a9392743892c72cd3b4e72c3c45774681e5e553 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -322,10 +322,10 @@ class File(h5py.File): ] @property - def entity_poses(self): - return self.select_entity_poses(None) + def poses(self): + return self.select_poses(None) - def select_entity_poses(self, predicate=None) -> Iterable: + def select_poses(self, predicate=None) -> Iterable: """ Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -379,7 +379,7 @@ class File(h5py.File): predicate = lambda e: e.name in names else: predicate = None - return self.select_entity_poses(predicate) + return self.select_poses(predicate) def validate(self, strict_validate: bool = True) -> (bool, str): """Validate the file to the specification. diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 68342fa2c9bb8e6dae45fc50feddc5cfceefc0cd..2b81429af20b05f064f66f1b159a6698b3f35ce0 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -80,16 +80,16 @@ def test_multiple_entities(): sf.validate() # The returned poses should be equal to the inserted poses - returned_poses = sf.entity_poses + returned_poses = sf.poses print(returned_poses) assert (returned_poses == poses).all() # Just get the array for some names - returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"]) + returned_poses = sf.select_poses(lambda e: e.name in ["fish_1", "fish_2"]) assert (returned_poses == poses[:2]).all() # Filter on both category and name - returned_poses = sf.select_entity_poses( + returned_poses = sf.select_poses( lambda e: e.category == "fish" and e.name == "fish_1" ) assert (returned_poses == poses[:1]).all() @@ -99,7 +99,7 @@ def test_multiple_entities(): "obstacle", poses=np.random.random((agents, timesteps, 4)) ) # Obstacles should not be returned when only fish are selected - returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") + returned_poses = sf.select_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() # for each of the entities @@ -138,6 +138,14 @@ def test_multiple_entities(): return sf +def test_deprecated_get_poses(): + f = test_multiple_entities() + with pytest.warns(DeprecationWarning): + assert f.get_poses().shape[0] == 10 + assert f.get_poses(category="fish").shape[0] == 7 + assert f.get_poses(names="fish_1").shape[0] == 1 + + def test_load_validate(): sf = robofish.io.File(path=valid_file_path) sf.validate()