From c838f70dc293647d2bb1b8de380bdb5c5716b831 Mon Sep 17 00:00:00 2001
From: Andi Gerken <andi.gerken@gmail.com>
Date: Fri, 26 Mar 2021 15:39:44 +0100
Subject: [PATCH] Added support for calculated orientation Added support for
 reading properties of multiple files.

---
 src/robofish/io/entity.py        | 24 +++++++++++++++++-------
 src/robofish/io/file.py          |  8 ++++++--
 src/robofish/io/io.py            | 20 +++++++++++++-------
 tests/robofish/io/test_entity.py | 13 ++++++-------
 tests/robofish/io/test_file.py   |  7 ++++++-
 tests/robofish/io/test_io.py     |  6 ++++--
 6 files changed, 52 insertions(+), 26 deletions(-)

diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py
index 5e8898a..e407989 100644
--- a/src/robofish/io/entity.py
+++ b/src/robofish/io/entity.py
@@ -137,6 +137,18 @@ class Entity(h5py.Group):
             return np.tile([1, 0], (self.positions.shape[0], 1))
         return self["orientations"]
 
+    @property
+    def orientations_calculated(self):
+        diff = np.diff(self.positions, axis=0)
+        angles = np.arctan2(diff[:, 1], diff[:, 0])
+        return angles[:, np.newaxis]
+
+    @property
+    def poses_calc_ori_rad(self):
+        return np.concatenate(
+            [self.positions[:-1], self.orientations_calculated], axis=1
+        )
+
     @property
     def poses(self):
         return np.concatenate([self.positions, self.orientations], axis=1)
@@ -145,13 +157,12 @@ class Entity(h5py.Group):
     def poses_rad(self):
         poses = self.poses
         # calculate the angles from the orientation vectors, write them to the third row and delete the fourth row
-        poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2])
-        poses = poses[:, :3]
-        return poses
+        ori_rad = np.arctan2(poses[:, 3], poses[:, 2])
+        return np.concatenate([poses[:, :2], ori_rad[:, np.newaxis]], axis=1)
 
     @property
-    def speed_turn_angle(self):
-        """Get the speed, turn and angles from the positions.
+    def speed_turn(self):
+        """Get the speed, turn and from the positions.
 
         The vectors pointing from each position to the next are computed.
         The output of the function describe these vectors.
@@ -160,7 +171,6 @@ class Entity(h5py.Group):
             The first column is the length of the vectors.
             The second column is the turning angle, required to get from one vector to the next.
             We assume, that the entity is oriented "correctly" in the first pose. So the first turn angle is 0.
-            The third column is the orientation of each vector.
         """
 
         diff = np.diff(self.positions, axis=0)
@@ -169,4 +179,4 @@ class Entity(h5py.Group):
         turn = np.zeros_like(angles)
         turn[0] = 0
         turn[1:] = utils.limit_angle_range(np.diff(angles))
-        return np.stack([speed, turn, angles], axis=-1)
\ No newline at end of file
+        return np.stack([speed, turn], axis=-1)
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 55ac2b2..9ff216c 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -366,11 +366,15 @@ class File(h5py.File):
         return self.select_entity_property(None, entity_property=Entity.poses_rad)
 
     @property
-    def speeds_turns_angles(self):
+    def entity_poses_calc_ori_rad(self):
         return self.select_entity_property(
-            None, entity_property=Entity.speed_turn_angle
+            None, entity_property=Entity.poses_calc_ori_rad
         )
 
+    @property
+    def entity_speeds_turns(self):
+        return self.select_entity_property(None, entity_property=Entity.speed_turn)
+
     def select_entity_poses(self, *args, ori_rad=False, **kwargs):
         entity_property = Entity.poses_rad if ori_rad else Entity.poses
         return self.select_entity_property(
diff --git a/src/robofish/io/io.py b/src/robofish/io/io.py
index 240496f..c8894ca 100644
--- a/src/robofish/io/io.py
+++ b/src/robofish/io/io.py
@@ -71,28 +71,34 @@ def read_multiple_files(
     return sf_dict
 
 
-def read_poses_from_multiple_files(
+def read_property_from_multiple_files(
     paths: Union[Path, str, Iterable[Path], Iterable[str]],
+    entity_property: property = None,
+    *,
     strict_validate: bool = False,
     max_files: int = None,
     shuffle: bool = False,
-    ori_rad: bool = False,
+    predicate: callable = None,
 ):
-    """Load hdf5 files from a given path and return the entity poses.
+    """Load hdf5 files from a given path and return the property of the entities.
 
     The function can be given the path to a single single hdf5 file, to a folder,
     containing hdf5 files, or an array of multiple files or folders.
 
     Args:
         path: The path to a hdf5 file or folder.
+        entity_property: A property of robofish.io.Entity default is Entity.poses_rad
         strict_validate: Choice between error and warning in case of invalidity
         max_files: Maximum number of files to be read
         shuffle: Shuffle the order of files
-        ori_rad: Return the orientations as radiants instead of unit vectors
+        predicate:
     Returns:
-        An array of all entity poses arrays
+        An array of all entity properties arrays
     """
 
+    assert (
+        entity_property is not None
+    ), "Please select an entity property e.g. 'Entity.poses_rad'"
     logging.info(f"Reading files from path {paths}")
 
     list_types = (list, np.ndarray, pandas.core.series.Series)
@@ -122,12 +128,12 @@ def read_poses_from_multiple_files(
                     with robofish.io.File(
                         path=file, strict_validate=strict_validate
                     ) as f:
-                        p = f.entity_poses_rad if ori_rad else f.entity_poses
+                        p = f.select_entity_property(predicate, entity_property)
                         poses_array.append(p)
 
         elif path is not None and path.exists():
             logging.info("found file %s" % path)
             with robofish.io.File(path=path, strict_validate=strict_validate) as f:
-                p = f.entity_poses_rad if ori_rad else f.entity_poses
+                p = f.select_entity_property(predicate, entity_property)
                 poses_array.append(p)
     return poses_array
\ No newline at end of file
diff --git a/tests/robofish/io/test_entity.py b/tests/robofish/io/test_entity.py
index 4cf14f4..7427f5a 100644
--- a/tests/robofish/io/test_entity.py
+++ b/tests/robofish/io/test_entity.py
@@ -39,21 +39,20 @@ def test_entity_turn_speed():
         [np.cos(circle_rad) * circle_size, np.sin(circle_rad) * circle_size], axis=-1
     )
     e = f.create_entity("fish", positions=positions)
-    speed_turn_angle = e.speed_turn_angle
-    assert speed_turn_angle.shape == (99, 3)
+    speed_turn = e.speed_turn
+    assert speed_turn.shape == (99, 2)
 
     # No turn in the first timestep, since initialization turns it the right way
-    assert speed_turn_angle[0, 1] == 0
+    assert speed_turn[0, 1] == 0
 
     # Turns and speeds shoud afterwards be all the same afterwards, since the fish swims with constant velocity and angular velocity.
-    assert (np.std(speed_turn_angle[1:, :2], axis=0) < 0.0001).all()
+    assert (np.std(speed_turn[1:], axis=0) < 0.0001).all()
 
     # Use turn_speed to generate positions
     gen_positions = np.zeros((positions.shape[0], 3))
-    gen_positions[0, :2] = positions[0]
-    gen_positions[0, 2] = speed_turn_angle[0, 2]
+    gen_positions[0] = e.poses_calc_ori_rad[0]
 
-    for i, (speed, turn, angle) in enumerate(speed_turn_angle):
+    for i, (speed, turn) in enumerate(speed_turn):
         new_angle = gen_positions[i, 2] + turn
         gen_positions[i + 1] = [
             gen_positions[i, 0] + np.cos(new_angle) * speed,
diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py
index 99ccbf4..fe6f1f9 100644
--- a/tests/robofish/io/test_file.py
+++ b/tests/robofish/io/test_file.py
@@ -143,7 +143,9 @@ def test_speeds_turns_angles():
     with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
         poses = np.zeros((10, 100, 3))
         f.create_multiple_entities("fish", poses=poses)
-        assert (f.speeds_turns_angles == 0).all()
+
+        # Stationary fish has no speed or turn
+        assert (f.entity_speeds_turns == 0).all()
 
 
 def test_broken_sampling(caplog):
@@ -185,6 +187,9 @@ def test_entity_positions_no_orientation():
         assert f.entity_poses.shape == (1, 100, 4)
         assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all()
 
+        # Calculate the orientation
+        assert f.entity_poses_calc_ori_rad.shape == (1, 99, 3)
+
 
 def test_load_validate():
     sf = robofish.io.File(path=valid_file_path)
diff --git a/tests/robofish/io/test_io.py b/tests/robofish/io/test_io.py
index 13bd876..bb4cb69 100644
--- a/tests/robofish/io/test_io.py
+++ b/tests/robofish/io/test_io.py
@@ -45,8 +45,10 @@ path = utils.full_path(__file__, "../../resources/valid.hdf5")
 
 # TODO read from folder of valid files
 @pytest.mark.parametrize("_path", [path, str(path)])
-def test_read_poses_from_multiple_folder(_path):
-    poses = robofish.io.read_poses_from_multiple_files([_path, _path])
+def test_read_poses_rad_from_multiple_folder(_path):
+    poses = robofish.io.read_property_from_multiple_files(
+        [_path, _path], robofish.io.entity.Entity.poses_rad
+    )
     # Should find the 3 presaved hdf5 files
     assert len(poses) == 2
     for p in poses:
-- 
GitLab