From e42758bb38b699cca80757a3939f8caf32b7033d Mon Sep 17 00:00:00 2001
From: Andi Gerken <andi.gerken@gmail.com>
Date: Tue, 23 Feb 2021 19:57:24 +0100
Subject: [PATCH] Fixed evaluate functions and corresponding tests

---
 src/robofish/evaluate/evaluate.py            | 23 ++++++++++++++------
 tests/robofish/evaluate/test_app_evaluate.py | 19 +++++++++++-----
 2 files changed, 29 insertions(+), 13 deletions(-)

diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 16b2e96..f9253e1 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -31,7 +31,7 @@ def get_all_poses_from_paths(paths: Iterable[str]):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
 
     # Read all poses from the files, shape (paths, files)
-    poses_per_path = [[f.get_poses() for f in files] for files in files_per_path]
+    poses_per_path = [[f.poses for f in files] for files in files_per_path]
 
     # close all files
     for p in files_per_path:
@@ -49,6 +49,7 @@ def evaluate_speed(
     for k, files in enumerate(files_per_path):
         path_speeds = []
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -57,7 +58,7 @@ def evaluate_speed(
             )
             for e_poses in poses:
                 e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1)
-                e_speeds *= file.get_frequency()
+                e_speeds *= file.frequency
                 path_speeds.extend(e_speeds)
         speeds.append(path_speeds)
 
@@ -87,6 +88,7 @@ def evaluate_turn(
     for k, files in enumerate(files_per_path):
         path_turns = []
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -95,7 +97,7 @@ def evaluate_turn(
             )
 
             # Todo check if all frequencies are the same
-            frequency = file.get_frequency()
+            frequency = file.frequency
 
             for e_poses in poses:
                 # convert ori_x, ori_y to radians
@@ -104,7 +106,7 @@ def evaluate_turn(
                 e_turns = ori_rad[1:] - ori_rad[:-1]
                 e_turns = np.where(e_turns < -np.pi, e_turns + 2 * np.pi, e_turns)
                 e_turns = np.where(e_turns > np.pi, e_turns - 2 * np.pi, e_turns)
-                # e_turns *= file.get_frequency()
+                # e_turns *= file.frequency
                 e_turns *= 180 / np.pi
                 path_turns.extend(e_turns)
         turns.append(path_turns)
@@ -135,6 +137,7 @@ def evaluate_orientation(
     orientations = []
     for k, files in enumerate(files_per_path):
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -206,13 +209,14 @@ def evaluate_relativeOrientation(
     for k, files in enumerate(files_per_path):
         path_orientations = []
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
                 if consider_categories is None
                 else consider_categories[k],
             )
-            all_poses = file.get_poses()
+            all_poses = file.poses
             for i in range(len(poses)):
                 for j in range(len(all_poses)):
                     if (poses[i] != all_poses[j]).any():
@@ -255,6 +259,7 @@ def evaluate_distanceToWall(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -326,6 +331,7 @@ def evaluate_tankpositions(
     for k, files in enumerate(files_per_path):
         path_x_pos, path_y_pos = [], []
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -371,6 +377,7 @@ def evaluate_trajectories(
     world_bounds = []
     for k, files in enumerate(files_per_path):
         for p, file in files.items():
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
@@ -445,13 +452,14 @@ def evaluate_positionVec(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
                 if consider_categories is None
                 else consider_categories[k],
             )
-            all_poses = file.get_poses()
+            all_poses = file.poses
             # calculate posVec for every fish combination
             for i in range(len(poses)):
                 for j in range(len(all_poses)):
@@ -505,13 +513,14 @@ def evaluate_follow_iid(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
+            # TODO: Change to select_poses()
             poses = file.get_poses(
                 names=None if consider_names is None else consider_names[k],
                 category=None
                 if consider_categories is None
                 else consider_categories[k],
             )
-            all_poses = file.get_poses()
+            all_poses = file.poses
             for i in range(len(poses)):
                 for j in range(len(all_poses)):
                     if (poses[i] != all_poses[j]).any():
diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py
index 6c47a0b..58ca278 100644
--- a/tests/robofish/evaluate/test_app_evaluate.py
+++ b/tests/robofish/evaluate/test_app_evaluate.py
@@ -6,16 +6,23 @@ from pathlib import Path
 
 logging.getLogger().setLevel(logging.INFO)
 
+h5py_files = [utils.full_path(__file__, "../../resources/valid.hdf5")]
+graphics_out = utils.full_path(__file__, "output_graph.png")
+if graphics_out.exists():
+    graphics_out.unlink()
 
-# TODO: reactivate test and change evaluate
-def deactivated_test_app_validate():
+
+def test_app_validate():
     """ This tests the function of the robofish-io-validate command """
 
     class DummyArgs:
-        def __init__(self, analysis_type, paths):
+        def __init__(self, analysis_type):
             self.analysis_type = analysis_type
-            self.paths = [utils.full_path(__file__, paths)]
+            self.paths = h5py_files
             self.names = None
-            self.save_path = None
+            self.save_path = graphics_out
 
-    app.evaluate(DummyArgs("speed", "../../resources/valid.hdf5"))
+    # TODO: Get rid of deprecated get_poses function
+    with pytest.warns(DeprecationWarning):
+        app.evaluate(DummyArgs("speed"))
+    graphics_out.unlink()
-- 
GitLab