Skip to content
Snippets Groups Projects
Commit c838f70d authored by Andi Gerken's avatar Andi Gerken
Browse files

Added support for calculated orientation

Added support for reading properties of multiple files.
parent e2f67823
Branches
Tags 0.1.5
No related merge requests found
Pipeline #36670 passed
...@@ -137,6 +137,18 @@ class Entity(h5py.Group): ...@@ -137,6 +137,18 @@ class Entity(h5py.Group):
return np.tile([1, 0], (self.positions.shape[0], 1)) return np.tile([1, 0], (self.positions.shape[0], 1))
return self["orientations"] return self["orientations"]
@property
def orientations_calculated(self):
diff = np.diff(self.positions, axis=0)
angles = np.arctan2(diff[:, 1], diff[:, 0])
return angles[:, np.newaxis]
@property
def poses_calc_ori_rad(self):
return np.concatenate(
[self.positions[:-1], self.orientations_calculated], axis=1
)
@property @property
def poses(self): def poses(self):
return np.concatenate([self.positions, self.orientations], axis=1) return np.concatenate([self.positions, self.orientations], axis=1)
...@@ -145,13 +157,12 @@ class Entity(h5py.Group): ...@@ -145,13 +157,12 @@ class Entity(h5py.Group):
def poses_rad(self): def poses_rad(self):
poses = self.poses poses = self.poses
# calculate the angles from the orientation vectors, write them to the third row and delete the fourth row # calculate the angles from the orientation vectors, write them to the third row and delete the fourth row
poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2]) ori_rad = np.arctan2(poses[:, 3], poses[:, 2])
poses = poses[:, :3] return np.concatenate([poses[:, :2], ori_rad[:, np.newaxis]], axis=1)
return poses
@property @property
def speed_turn_angle(self): def speed_turn(self):
"""Get the speed, turn and angles from the positions. """Get the speed, turn and from the positions.
The vectors pointing from each position to the next are computed. The vectors pointing from each position to the next are computed.
The output of the function describe these vectors. The output of the function describe these vectors.
...@@ -160,7 +171,6 @@ class Entity(h5py.Group): ...@@ -160,7 +171,6 @@ class Entity(h5py.Group):
The first column is the length of the vectors. The first column is the length of the vectors.
The second column is the turning angle, required to get from one vector to the next. The second column is the turning angle, required to get from one vector to the next.
We assume, that the entity is oriented "correctly" in the first pose. So the first turn angle is 0. We assume, that the entity is oriented "correctly" in the first pose. So the first turn angle is 0.
The third column is the orientation of each vector.
""" """
diff = np.diff(self.positions, axis=0) diff = np.diff(self.positions, axis=0)
...@@ -169,4 +179,4 @@ class Entity(h5py.Group): ...@@ -169,4 +179,4 @@ class Entity(h5py.Group):
turn = np.zeros_like(angles) turn = np.zeros_like(angles)
turn[0] = 0 turn[0] = 0
turn[1:] = utils.limit_angle_range(np.diff(angles)) turn[1:] = utils.limit_angle_range(np.diff(angles))
return np.stack([speed, turn, angles], axis=-1) return np.stack([speed, turn], axis=-1)
\ No newline at end of file
...@@ -366,11 +366,15 @@ class File(h5py.File): ...@@ -366,11 +366,15 @@ class File(h5py.File):
return self.select_entity_property(None, entity_property=Entity.poses_rad) return self.select_entity_property(None, entity_property=Entity.poses_rad)
@property @property
def speeds_turns_angles(self): def entity_poses_calc_ori_rad(self):
return self.select_entity_property( return self.select_entity_property(
None, entity_property=Entity.speed_turn_angle None, entity_property=Entity.poses_calc_ori_rad
) )
@property
def entity_speeds_turns(self):
return self.select_entity_property(None, entity_property=Entity.speed_turn)
def select_entity_poses(self, *args, ori_rad=False, **kwargs): def select_entity_poses(self, *args, ori_rad=False, **kwargs):
entity_property = Entity.poses_rad if ori_rad else Entity.poses entity_property = Entity.poses_rad if ori_rad else Entity.poses
return self.select_entity_property( return self.select_entity_property(
......
...@@ -71,28 +71,34 @@ def read_multiple_files( ...@@ -71,28 +71,34 @@ def read_multiple_files(
return sf_dict return sf_dict
def read_poses_from_multiple_files( def read_property_from_multiple_files(
paths: Union[Path, str, Iterable[Path], Iterable[str]], paths: Union[Path, str, Iterable[Path], Iterable[str]],
entity_property: property = None,
*,
strict_validate: bool = False, strict_validate: bool = False,
max_files: int = None, max_files: int = None,
shuffle: bool = False, shuffle: bool = False,
ori_rad: bool = False, predicate: callable = None,
): ):
"""Load hdf5 files from a given path and return the entity poses. """Load hdf5 files from a given path and return the property of the entities.
The function can be given the path to a single single hdf5 file, to a folder, The function can be given the path to a single single hdf5 file, to a folder,
containing hdf5 files, or an array of multiple files or folders. containing hdf5 files, or an array of multiple files or folders.
Args: Args:
path: The path to a hdf5 file or folder. path: The path to a hdf5 file or folder.
entity_property: A property of robofish.io.Entity default is Entity.poses_rad
strict_validate: Choice between error and warning in case of invalidity strict_validate: Choice between error and warning in case of invalidity
max_files: Maximum number of files to be read max_files: Maximum number of files to be read
shuffle: Shuffle the order of files shuffle: Shuffle the order of files
ori_rad: Return the orientations as radiants instead of unit vectors predicate:
Returns: Returns:
An array of all entity poses arrays An array of all entity properties arrays
""" """
assert (
entity_property is not None
), "Please select an entity property e.g. 'Entity.poses_rad'"
logging.info(f"Reading files from path {paths}") logging.info(f"Reading files from path {paths}")
list_types = (list, np.ndarray, pandas.core.series.Series) list_types = (list, np.ndarray, pandas.core.series.Series)
...@@ -122,12 +128,12 @@ def read_poses_from_multiple_files( ...@@ -122,12 +128,12 @@ def read_poses_from_multiple_files(
with robofish.io.File( with robofish.io.File(
path=file, strict_validate=strict_validate path=file, strict_validate=strict_validate
) as f: ) as f:
p = f.entity_poses_rad if ori_rad else f.entity_poses p = f.select_entity_property(predicate, entity_property)
poses_array.append(p) poses_array.append(p)
elif path is not None and path.exists(): elif path is not None and path.exists():
logging.info("found file %s" % path) logging.info("found file %s" % path)
with robofish.io.File(path=path, strict_validate=strict_validate) as f: with robofish.io.File(path=path, strict_validate=strict_validate) as f:
p = f.entity_poses_rad if ori_rad else f.entity_poses p = f.select_entity_property(predicate, entity_property)
poses_array.append(p) poses_array.append(p)
return poses_array return poses_array
\ No newline at end of file
...@@ -39,21 +39,20 @@ def test_entity_turn_speed(): ...@@ -39,21 +39,20 @@ def test_entity_turn_speed():
[np.cos(circle_rad) * circle_size, np.sin(circle_rad) * circle_size], axis=-1 [np.cos(circle_rad) * circle_size, np.sin(circle_rad) * circle_size], axis=-1
) )
e = f.create_entity("fish", positions=positions) e = f.create_entity("fish", positions=positions)
speed_turn_angle = e.speed_turn_angle speed_turn = e.speed_turn
assert speed_turn_angle.shape == (99, 3) assert speed_turn.shape == (99, 2)
# No turn in the first timestep, since initialization turns it the right way # No turn in the first timestep, since initialization turns it the right way
assert speed_turn_angle[0, 1] == 0 assert speed_turn[0, 1] == 0
# Turns and speeds shoud afterwards be all the same afterwards, since the fish swims with constant velocity and angular velocity. # Turns and speeds shoud afterwards be all the same afterwards, since the fish swims with constant velocity and angular velocity.
assert (np.std(speed_turn_angle[1:, :2], axis=0) < 0.0001).all() assert (np.std(speed_turn[1:], axis=0) < 0.0001).all()
# Use turn_speed to generate positions # Use turn_speed to generate positions
gen_positions = np.zeros((positions.shape[0], 3)) gen_positions = np.zeros((positions.shape[0], 3))
gen_positions[0, :2] = positions[0] gen_positions[0] = e.poses_calc_ori_rad[0]
gen_positions[0, 2] = speed_turn_angle[0, 2]
for i, (speed, turn, angle) in enumerate(speed_turn_angle): for i, (speed, turn) in enumerate(speed_turn):
new_angle = gen_positions[i, 2] + turn new_angle = gen_positions[i, 2] + turn
gen_positions[i + 1] = [ gen_positions[i + 1] = [
gen_positions[i, 0] + np.cos(new_angle) * speed, gen_positions[i, 0] + np.cos(new_angle) * speed,
......
...@@ -143,7 +143,9 @@ def test_speeds_turns_angles(): ...@@ -143,7 +143,9 @@ def test_speeds_turns_angles():
with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f: with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
poses = np.zeros((10, 100, 3)) poses = np.zeros((10, 100, 3))
f.create_multiple_entities("fish", poses=poses) f.create_multiple_entities("fish", poses=poses)
assert (f.speeds_turns_angles == 0).all()
# Stationary fish has no speed or turn
assert (f.entity_speeds_turns == 0).all()
def test_broken_sampling(caplog): def test_broken_sampling(caplog):
...@@ -185,6 +187,9 @@ def test_entity_positions_no_orientation(): ...@@ -185,6 +187,9 @@ def test_entity_positions_no_orientation():
assert f.entity_poses.shape == (1, 100, 4) assert f.entity_poses.shape == (1, 100, 4)
assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all() assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all()
# Calculate the orientation
assert f.entity_poses_calc_ori_rad.shape == (1, 99, 3)
def test_load_validate(): def test_load_validate():
sf = robofish.io.File(path=valid_file_path) sf = robofish.io.File(path=valid_file_path)
......
...@@ -45,8 +45,10 @@ path = utils.full_path(__file__, "../../resources/valid.hdf5") ...@@ -45,8 +45,10 @@ path = utils.full_path(__file__, "../../resources/valid.hdf5")
# TODO read from folder of valid files # TODO read from folder of valid files
@pytest.mark.parametrize("_path", [path, str(path)]) @pytest.mark.parametrize("_path", [path, str(path)])
def test_read_poses_from_multiple_folder(_path): def test_read_poses_rad_from_multiple_folder(_path):
poses = robofish.io.read_poses_from_multiple_files([_path, _path]) poses = robofish.io.read_property_from_multiple_files(
[_path, _path], robofish.io.entity.Entity.poses_rad
)
# Should find the 3 presaved hdf5 files # Should find the 3 presaved hdf5 files
assert len(poses) == 2 assert len(poses) == 2
for p in poses: for p in poses:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment