Skip to content
Snippets Groups Projects
Commit e42758bb authored by Andi Gerken's avatar Andi Gerken
Browse files

Fixed evaluate functions and corresponding tests

parent 4417d8ed
No related branches found
No related tags found
No related merge requests found
...@@ -31,7 +31,7 @@ def get_all_poses_from_paths(paths: Iterable[str]): ...@@ -31,7 +31,7 @@ def get_all_poses_from_paths(paths: Iterable[str]):
files_per_path = [robofish.io.read_multiple_files(p) for p in paths] files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
# Read all poses from the files, shape (paths, files) # Read all poses from the files, shape (paths, files)
poses_per_path = [[f.get_poses() for f in files] for files in files_per_path] poses_per_path = [[f.poses for f in files] for files in files_per_path]
# close all files # close all files
for p in files_per_path: for p in files_per_path:
...@@ -49,6 +49,7 @@ def evaluate_speed( ...@@ -49,6 +49,7 @@ def evaluate_speed(
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
path_speeds = [] path_speeds = []
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -57,7 +58,7 @@ def evaluate_speed( ...@@ -57,7 +58,7 @@ def evaluate_speed(
) )
for e_poses in poses: for e_poses in poses:
e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1) e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1)
e_speeds *= file.get_frequency() e_speeds *= file.frequency
path_speeds.extend(e_speeds) path_speeds.extend(e_speeds)
speeds.append(path_speeds) speeds.append(path_speeds)
...@@ -87,6 +88,7 @@ def evaluate_turn( ...@@ -87,6 +88,7 @@ def evaluate_turn(
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
path_turns = [] path_turns = []
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -95,7 +97,7 @@ def evaluate_turn( ...@@ -95,7 +97,7 @@ def evaluate_turn(
) )
# Todo check if all frequencies are the same # Todo check if all frequencies are the same
frequency = file.get_frequency() frequency = file.frequency
for e_poses in poses: for e_poses in poses:
# convert ori_x, ori_y to radians # convert ori_x, ori_y to radians
...@@ -104,7 +106,7 @@ def evaluate_turn( ...@@ -104,7 +106,7 @@ def evaluate_turn(
e_turns = ori_rad[1:] - ori_rad[:-1] e_turns = ori_rad[1:] - ori_rad[:-1]
e_turns = np.where(e_turns < -np.pi, e_turns + 2 * np.pi, e_turns) e_turns = np.where(e_turns < -np.pi, e_turns + 2 * np.pi, e_turns)
e_turns = np.where(e_turns > np.pi, e_turns - 2 * np.pi, e_turns) e_turns = np.where(e_turns > np.pi, e_turns - 2 * np.pi, e_turns)
# e_turns *= file.get_frequency() # e_turns *= file.frequency
e_turns *= 180 / np.pi e_turns *= 180 / np.pi
path_turns.extend(e_turns) path_turns.extend(e_turns)
turns.append(path_turns) turns.append(path_turns)
...@@ -135,6 +137,7 @@ def evaluate_orientation( ...@@ -135,6 +137,7 @@ def evaluate_orientation(
orientations = [] orientations = []
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -206,13 +209,14 @@ def evaluate_relativeOrientation( ...@@ -206,13 +209,14 @@ def evaluate_relativeOrientation(
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
path_orientations = [] path_orientations = []
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
if consider_categories is None if consider_categories is None
else consider_categories[k], else consider_categories[k],
) )
all_poses = file.get_poses() all_poses = file.poses
for i in range(len(poses)): for i in range(len(poses)):
for j in range(len(all_poses)): for j in range(len(all_poses)):
if (poses[i] != all_poses[j]).any(): if (poses[i] != all_poses[j]).any():
...@@ -255,6 +259,7 @@ def evaluate_distanceToWall( ...@@ -255,6 +259,7 @@ def evaluate_distanceToWall(
for p, file in files.items(): for p, file in files.items():
worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsX.append(file.attrs["world_size_cm"][0])
worldBoundsY.append(file.attrs["world_size_cm"][1]) worldBoundsY.append(file.attrs["world_size_cm"][1])
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -326,6 +331,7 @@ def evaluate_tankpositions( ...@@ -326,6 +331,7 @@ def evaluate_tankpositions(
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
path_x_pos, path_y_pos = [], [] path_x_pos, path_y_pos = [], []
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -371,6 +377,7 @@ def evaluate_trajectories( ...@@ -371,6 +377,7 @@ def evaluate_trajectories(
world_bounds = [] world_bounds = []
for k, files in enumerate(files_per_path): for k, files in enumerate(files_per_path):
for p, file in files.items(): for p, file in files.items():
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
...@@ -445,13 +452,14 @@ def evaluate_positionVec( ...@@ -445,13 +452,14 @@ def evaluate_positionVec(
for p, file in files.items(): for p, file in files.items():
worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsX.append(file.attrs["world_size_cm"][0])
worldBoundsY.append(file.attrs["world_size_cm"][1]) worldBoundsY.append(file.attrs["world_size_cm"][1])
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
if consider_categories is None if consider_categories is None
else consider_categories[k], else consider_categories[k],
) )
all_poses = file.get_poses() all_poses = file.poses
# calculate posVec for every fish combination # calculate posVec for every fish combination
for i in range(len(poses)): for i in range(len(poses)):
for j in range(len(all_poses)): for j in range(len(all_poses)):
...@@ -505,13 +513,14 @@ def evaluate_follow_iid( ...@@ -505,13 +513,14 @@ def evaluate_follow_iid(
for p, file in files.items(): for p, file in files.items():
worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsX.append(file.attrs["world_size_cm"][0])
worldBoundsY.append(file.attrs["world_size_cm"][1]) worldBoundsY.append(file.attrs["world_size_cm"][1])
# TODO: Change to select_poses()
poses = file.get_poses( poses = file.get_poses(
names=None if consider_names is None else consider_names[k], names=None if consider_names is None else consider_names[k],
category=None category=None
if consider_categories is None if consider_categories is None
else consider_categories[k], else consider_categories[k],
) )
all_poses = file.get_poses() all_poses = file.poses
for i in range(len(poses)): for i in range(len(poses)):
for j in range(len(all_poses)): for j in range(len(all_poses)):
if (poses[i] != all_poses[j]).any(): if (poses[i] != all_poses[j]).any():
......
...@@ -6,16 +6,23 @@ from pathlib import Path ...@@ -6,16 +6,23 @@ from pathlib import Path
logging.getLogger().setLevel(logging.INFO) logging.getLogger().setLevel(logging.INFO)
h5py_files = [utils.full_path(__file__, "../../resources/valid.hdf5")]
graphics_out = utils.full_path(__file__, "output_graph.png")
if graphics_out.exists():
graphics_out.unlink()
# TODO: reactivate test and change evaluate
def deactivated_test_app_validate(): def test_app_validate():
""" This tests the function of the robofish-io-validate command """ """ This tests the function of the robofish-io-validate command """
class DummyArgs: class DummyArgs:
def __init__(self, analysis_type, paths): def __init__(self, analysis_type):
self.analysis_type = analysis_type self.analysis_type = analysis_type
self.paths = [utils.full_path(__file__, paths)] self.paths = h5py_files
self.names = None self.names = None
self.save_path = None self.save_path = graphics_out
app.evaluate(DummyArgs("speed", "../../resources/valid.hdf5")) # TODO: Get rid of deprecated get_poses function
with pytest.warns(DeprecationWarning):
app.evaluate(DummyArgs("speed"))
graphics_out.unlink()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment