From cf02c73b6305f60cd2d442bb79a788ca79160bf1 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Tue, 23 Feb 2021 19:24:39 +0100 Subject: [PATCH] Renamed entity_poses to poses and select_entity_poses() to select_poses() --- examples/example_basic.ipynb | 4 ++-- examples/example_basic.py | 4 ++-- examples/example_readme.py | 2 +- src/robofish/io/file.py | 8 ++++---- tests/robofish/io/test_file.py | 16 ++++++++++++---- 5 files changed, 21 insertions(+), 13 deletions(-) diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index ada1579..5445fe8 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -156,10 +156,10 @@ " # different time scales and have another function, which generates an\n", " # interpolated array.\n", " print(\"\\nAll poses\")\n", - " print(sf.select_entity_poses())\n", + " print(sf.select_poses())\n", "\n", " print(\"\\nFish poses\")\n", - " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", + " print(sf.select_poses(lambda e: e.category == \"fish\"))\n", "\n", " print(\"\\nFile structure\")\n", " print(sf)\n" diff --git a/examples/example_basic.py b/examples/example_basic.py index f0d46d1..70c39b6 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -43,10 +43,10 @@ print(sf.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") -print(sf.entity_poses) +print(sf.poses) print("\nFish poses") -print(sf.select_entity_poses(lambda e: e.category == "fish")) +print(sf.select_poses(lambda e: e.category == "fish")) print("\nFile structure") print(sf) diff --git a/examples/example_readme.py b/examples/example_readme.py index d634df0..71442e2 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -32,4 +32,4 @@ with robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0) as # Show and save the file print(f) - print("Poses Shape: ", f.entity_poses.shape) + print("Poses Shape: ", f.poses.shape) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index da0edc3..2a93927 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -322,10 +322,10 @@ class File(h5py.File): ] @property - def entity_poses(self): - return self.select_entity_poses(None) + def poses(self): + return self.select_poses(None) - def select_entity_poses(self, predicate=None) -> Iterable: + def select_poses(self, predicate=None) -> Iterable: """ Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -379,7 +379,7 @@ class File(h5py.File): predicate = lambda e: e.name in names else: predicate = None - return self.select_entity_poses(predicate) + return self.select_poses(predicate) def validate(self, strict_validate: bool = True) -> (bool, str): """Validate the file to the specification. diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 68342fa..2b81429 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -80,16 +80,16 @@ def test_multiple_entities(): sf.validate() # The returned poses should be equal to the inserted poses - returned_poses = sf.entity_poses + returned_poses = sf.poses print(returned_poses) assert (returned_poses == poses).all() # Just get the array for some names - returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"]) + returned_poses = sf.select_poses(lambda e: e.name in ["fish_1", "fish_2"]) assert (returned_poses == poses[:2]).all() # Filter on both category and name - returned_poses = sf.select_entity_poses( + returned_poses = sf.select_poses( lambda e: e.category == "fish" and e.name == "fish_1" ) assert (returned_poses == poses[:1]).all() @@ -99,7 +99,7 @@ def test_multiple_entities(): "obstacle", poses=np.random.random((agents, timesteps, 4)) ) # Obstacles should not be returned when only fish are selected - returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") + returned_poses = sf.select_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() # for each of the entities @@ -138,6 +138,14 @@ def test_multiple_entities(): return sf +def test_deprecated_get_poses(): + f = test_multiple_entities() + with pytest.warns(DeprecationWarning): + assert f.get_poses().shape[0] == 10 + assert f.get_poses(category="fish").shape[0] == 7 + assert f.get_poses(names="fish_1").shape[0] == 1 + + def test_load_validate(): sf = robofish.io.File(path=valid_file_path) sf.validate() -- GitLab