Skip to content
Snippets Groups Projects
Commit fa6de1cf authored by Andi Gerken's avatar Andi Gerken
Browse files

Reverted renaming of entity_poses to poses. The functions are called

- entity_poses
- select_entity_poses()
again.
parent e42758bb
Branches
Tags
No related merge requests found
Pipeline #35785 passed
......@@ -44,10 +44,10 @@ def create_example_file(path):
# Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.
print("\nAll poses")
print(sf.poses)
print(sf.entity_poses)
print("\nFish poses")
print(sf.select_poses(lambda e: e.category == "fish"))
print(sf.select_entity_poses(lambda e: e.category == "fish"))
print("\nFile structure")
print(sf)
......
......@@ -29,7 +29,7 @@ def create_example_file(path):
# Show and save the file
print(f)
print("Poses Shape: ", f.poses.shape)
print("Poses Shape: ", f.entity_poses.shape)
if __name__ == "__main__":
......
......@@ -31,7 +31,7 @@ def get_all_poses_from_paths(paths: Iterable[str]):
files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
# Read all poses from the files, shape (paths, files)
poses_per_path = [[f.poses for f in files] for files in files_per_path]
poses_per_path = [[f.entity_poses for f in files] for files in files_per_path]
# close all files
for p in files_per_path:
......@@ -216,7 +216,7 @@ def evaluate_relativeOrientation(
if consider_categories is None
else consider_categories[k],
)
all_poses = file.poses
all_poses = file.entity_poses
for i in range(len(poses)):
for j in range(len(all_poses)):
if (poses[i] != all_poses[j]).any():
......@@ -459,7 +459,7 @@ def evaluate_positionVec(
if consider_categories is None
else consider_categories[k],
)
all_poses = file.poses
all_poses = file.entity_poses
# calculate posVec for every fish combination
for i in range(len(poses)):
for j in range(len(all_poses)):
......@@ -520,7 +520,7 @@ def evaluate_follow_iid(
if consider_categories is None
else consider_categories[k],
)
all_poses = file.poses
all_poses = file.entity_poses
for i in range(len(poses)):
for j in range(len(all_poses)):
if (poses[i] != all_poses[j]).any():
......
......@@ -324,10 +324,10 @@ class File(h5py.File):
]
@property
def poses(self):
return self.select_poses(None)
def entity_poses(self):
return self.select_entity_poses(None)
def select_poses(self, predicate=None) -> Iterable:
def select_entity_poses(self, predicate=None) -> Iterable:
""" Select an array of the poses of entities
If no name or category is specified, all entities will be selected.
......@@ -381,7 +381,7 @@ class File(h5py.File):
predicate = lambda e: e.name in names
else:
predicate = None
return self.select_poses(predicate)
return self.select_entity_poses(predicate)
def validate(self, strict_validate: bool = True) -> (bool, str):
"""Validate the file to the specification.
......
......@@ -80,16 +80,16 @@ def test_multiple_entities():
sf.validate()
# The returned poses should be equal to the inserted poses
returned_poses = sf.poses
returned_poses = sf.entity_poses
print(returned_poses)
assert (returned_poses == poses).all()
# Just get the array for some names
returned_poses = sf.select_poses(lambda e: e.name in ["fish_1", "fish_2"])
returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"])
assert (returned_poses == poses[:2]).all()
# Filter on both category and name
returned_poses = sf.select_poses(
returned_poses = sf.select_entity_poses(
lambda e: e.category == "fish" and e.name == "fish_1"
)
assert (returned_poses == poses[:1]).all()
......@@ -99,7 +99,7 @@ def test_multiple_entities():
"obstacle", poses=np.random.random((agents, timesteps, 4))
)
# Obstacles should not be returned when only fish are selected
returned_poses = sf.select_poses(lambda e: e.category == "fish")
returned_poses = sf.select_entity_poses(lambda e: e.category == "fish")
assert (returned_poses == poses).all()
# for each of the entities
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment