From e42758bb38b699cca80757a3939f8caf32b7033d Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Tue, 23 Feb 2021 19:57:24 +0100 Subject: [PATCH] Fixed evaluate functions and corresponding tests --- src/robofish/evaluate/evaluate.py | 23 ++++++++++++++------ tests/robofish/evaluate/test_app_evaluate.py | 19 +++++++++++----- 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 16b2e96..f9253e1 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -31,7 +31,7 @@ def get_all_poses_from_paths(paths: Iterable[str]): files_per_path = [robofish.io.read_multiple_files(p) for p in paths] # Read all poses from the files, shape (paths, files) - poses_per_path = [[f.get_poses() for f in files] for files in files_per_path] + poses_per_path = [[f.poses for f in files] for files in files_per_path] # close all files for p in files_per_path: @@ -49,6 +49,7 @@ def evaluate_speed( for k, files in enumerate(files_per_path): path_speeds = [] for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -57,7 +58,7 @@ def evaluate_speed( ) for e_poses in poses: e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1) - e_speeds *= file.get_frequency() + e_speeds *= file.frequency path_speeds.extend(e_speeds) speeds.append(path_speeds) @@ -87,6 +88,7 @@ def evaluate_turn( for k, files in enumerate(files_per_path): path_turns = [] for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -95,7 +97,7 @@ def evaluate_turn( ) # Todo check if all frequencies are the same - frequency = file.get_frequency() + frequency = file.frequency for e_poses in poses: # convert ori_x, ori_y to radians @@ -104,7 +106,7 @@ def evaluate_turn( e_turns = ori_rad[1:] - ori_rad[:-1] e_turns = np.where(e_turns < -np.pi, e_turns + 2 * np.pi, e_turns) e_turns = np.where(e_turns > np.pi, e_turns - 2 * np.pi, e_turns) - # e_turns *= file.get_frequency() + # e_turns *= file.frequency e_turns *= 180 / np.pi path_turns.extend(e_turns) turns.append(path_turns) @@ -135,6 +137,7 @@ def evaluate_orientation( orientations = [] for k, files in enumerate(files_per_path): for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -206,13 +209,14 @@ def evaluate_relativeOrientation( for k, files in enumerate(files_per_path): path_orientations = [] for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None if consider_categories is None else consider_categories[k], ) - all_poses = file.get_poses() + all_poses = file.poses for i in range(len(poses)): for j in range(len(all_poses)): if (poses[i] != all_poses[j]).any(): @@ -255,6 +259,7 @@ def evaluate_distanceToWall( for p, file in files.items(): worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsY.append(file.attrs["world_size_cm"][1]) + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -326,6 +331,7 @@ def evaluate_tankpositions( for k, files in enumerate(files_per_path): path_x_pos, path_y_pos = [], [] for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -371,6 +377,7 @@ def evaluate_trajectories( world_bounds = [] for k, files in enumerate(files_per_path): for p, file in files.items(): + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None @@ -445,13 +452,14 @@ def evaluate_positionVec( for p, file in files.items(): worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsY.append(file.attrs["world_size_cm"][1]) + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None if consider_categories is None else consider_categories[k], ) - all_poses = file.get_poses() + all_poses = file.poses # calculate posVec for every fish combination for i in range(len(poses)): for j in range(len(all_poses)): @@ -505,13 +513,14 @@ def evaluate_follow_iid( for p, file in files.items(): worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsY.append(file.attrs["world_size_cm"][1]) + # TODO: Change to select_poses() poses = file.get_poses( names=None if consider_names is None else consider_names[k], category=None if consider_categories is None else consider_categories[k], ) - all_poses = file.get_poses() + all_poses = file.poses for i in range(len(poses)): for j in range(len(all_poses)): if (poses[i] != all_poses[j]).any(): diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 6c47a0b..58ca278 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -6,16 +6,23 @@ from pathlib import Path logging.getLogger().setLevel(logging.INFO) +h5py_files = [utils.full_path(__file__, "../../resources/valid.hdf5")] +graphics_out = utils.full_path(__file__, "output_graph.png") +if graphics_out.exists(): + graphics_out.unlink() -# TODO: reactivate test and change evaluate -def deactivated_test_app_validate(): + +def test_app_validate(): """ This tests the function of the robofish-io-validate command """ class DummyArgs: - def __init__(self, analysis_type, paths): + def __init__(self, analysis_type): self.analysis_type = analysis_type - self.paths = [utils.full_path(__file__, paths)] + self.paths = h5py_files self.names = None - self.save_path = None + self.save_path = graphics_out - app.evaluate(DummyArgs("speed", "../../resources/valid.hdf5")) + # TODO: Get rid of deprecated get_poses function + with pytest.warns(DeprecationWarning): + app.evaluate(DummyArgs("speed")) + graphics_out.unlink() -- GitLab