From 2a9ddb37a540bb4825bf2240be591249536ccb18 Mon Sep 17 00:00:00 2001
From: marc131183 <marcg99@zedat.fu-berlin.de>
Date: Wed, 3 Mar 2021 13:47:39 +0100
Subject: [PATCH] removed consider_ and added predicate, to select fishes

---
 src/robofish/evaluate/evaluate.py | 199 +++++++-----------------------
 1 file changed, 47 insertions(+), 152 deletions(-)

diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index d47e0d5..f3d0e99 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -19,29 +19,20 @@ from typing import Iterable
 from scipy import stats
 
 
-<<<<<<< HEAD
-# def get_all_poses_from_paths(paths: Iterable(str)):
-#     """This function reads all poses from given paths.
-=======
 def get_all_poses_from_paths(paths: Iterable[str]):
     """This function reads all poses from given paths.
->>>>>>> master
-
-#     Args:
-#         paths: An array of strings, with files or folders
-#     Returns:
-#         An array, containing poses with the shape [paths][files][entities, timesteps, 4]
-#     """
-#     # Open all files, shape (paths, files)
-#     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
-
-<<<<<<< HEAD
-#     # Read all poses from the files, shape (paths, files)
-#     poses_per_path = [[f.get_poses() for f in files] for files in files_per_path]
-=======
+
+    #     Args:
+    #         paths: An array of strings, with files or folders
+    #     Returns:
+    #         An array, containing poses with the shape [paths][files][entities, timesteps, 4]
+    #"""
+    #     # Open all files, shape (paths, files)
+    #     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+
     # Read all poses from the files, shape (paths, files)
     poses_per_path = [[f.entity_poses for f in files] for files in files_per_path]
->>>>>>> master
+
 
 #     # close all files
 #     for p in files_per_path:
@@ -51,20 +42,14 @@ def get_all_poses_from_paths(paths: Iterable[str]):
 #     return poses_per_path
 
 
-def evaluate_speed(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_speed(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     speeds = []
     for k, files in enumerate(files_per_path):
         path_speeds = []
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             for e_poses in poses:
                 e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1)
@@ -90,20 +75,14 @@ def evaluate_speed(
     plt.close()
 
 
-def evaluate_turn(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_turn(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     turns = []
     for k, files in enumerate(files_per_path):
         path_turns = []
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
 
             # Todo check if all frequencies are the same
@@ -140,19 +119,13 @@ def evaluate_turn(
     plt.close()
 
 
-def evaluate_orientation(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_orientation(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     orientations = []
     for k, files in enumerate(files_per_path):
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             poses = poses.reshape((len(poses) * len(poses[0]), 4))
             world_size = file.attrs["world_size_cm"]
@@ -211,20 +184,14 @@ def evaluate_orientation(
     plt.close()
 
 
-def evaluate_relativeOrientation(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_relativeOrientation(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     orientations = []
     for k, files in enumerate(files_per_path):
         path_orientations = []
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             all_poses = file.entity_poses
             for i in range(len(poses)):
@@ -255,9 +222,7 @@ def evaluate_relativeOrientation(
     plt.close()
 
 
-def evaluate_distanceToWall(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_distanceToWall(paths, names=None, save_path=None, predicate=None):
     """
     only works for rectangular tanks
     """
@@ -269,12 +234,8 @@ def evaluate_distanceToWall(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             world_size = file.attrs["world_size_cm"]
             world_bounds = [
@@ -328,9 +289,7 @@ def evaluate_distanceToWall(
     plt.close()
 
 
-def evaluate_tankpositions(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None):
     """
     Heatmap of fishpositions
     By Moritz Maxeiner
@@ -341,12 +300,8 @@ def evaluate_tankpositions(
     for k, files in enumerate(files_per_path):
         path_x_pos, path_y_pos = [], []
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             world_bounds.append(file.attrs["world_size_cm"])
             for e_poses in poses:
@@ -375,9 +330,7 @@ def evaluate_tankpositions(
     plt.close()
 
 
-def evaluate_trajectories(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_trajectories(paths, names=None, save_path=None, predicate=None):
     """
     trajectories of fishes
     By Moritz Maxeiner
@@ -387,12 +340,8 @@ def evaluate_trajectories(
     world_bounds = []
     for k, files in enumerate(files_per_path):
         for p, file in files.items():
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             world_bounds.append(file.attrs["world_size_cm"])
             path_pos = {
@@ -451,9 +400,7 @@ def evaluate_trajectories(
     plt.close()
 
 
-def evaluate_positionVec(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_positionVec(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     posVec = []
     worldBoundsX, worldBoundsY = [], []
@@ -462,12 +409,8 @@ def evaluate_positionVec(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             all_poses = file.entity_poses
             # calculate posVec for every fish combination
@@ -512,9 +455,7 @@ def evaluate_positionVec(
     plt.close()
 
 
-def evaluate_follow_iid(
-    paths, names=None, save_path=None, consider_names=None, consider_categories=None
-):
+def evaluate_follow_iid(paths, names=None, save_path=None, predicate=None):
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
     follow, iid = [], []
     worldBoundsX, worldBoundsY = [], []
@@ -523,12 +464,8 @@ def evaluate_follow_iid(
         for p, file in files.items():
             worldBoundsX.append(file.attrs["world_size_cm"][0])
             worldBoundsY.append(file.attrs["world_size_cm"][1])
-            # TODO: Change to select_poses()
-            poses = file.get_poses(
-                names=None if consider_names is None else consider_names[k],
-                category=None
-                if consider_categories is None
-                else consider_categories[k],
+            poses = file.select_entity_poses(
+                None if predicate is None else predicate[k]
             )
             all_poses = file.entity_poses
             for i in range(len(poses)):
@@ -586,60 +523,18 @@ def evaluate_follow_iid(
     plt.close()
 
 
-def evaluate_all(
-    paths, names=None, save_folder=None, consider_names=None, consider_categories=None
-):
-    evaluate_speed(
-        paths, names, save_folder + "speed.png", consider_names, consider_categories
-    )
-    evaluate_turn(
-        paths, names, save_folder + "turn.png", consider_names, consider_categories
-    )
-    evaluate_orientation(
-        paths,
-        names,
-        save_folder + "orientation.png",
-        consider_names,
-        consider_categories,
-    )
+def evaluate_all(paths, names=None, save_folder=None, predicate=None):
+    evaluate_speed(paths, names, save_folder + "speed.png", predicate)
+    evaluate_turn(paths, names, save_folder + "turn.png", predicate)
+    evaluate_orientation(paths, names, save_folder + "orientation.png", predicate)
     evaluate_relativeOrientation(
-        paths,
-        names,
-        save_folder + "relativeOrientation.png",
-        consider_names,
-        consider_categories,
-    )
-    evaluate_distanceToWall(
-        paths,
-        names,
-        save_folder + "distanceToWall.png",
-        consider_names,
-        consider_categories,
-    )
-    evaluate_tankpositions(
-        paths,
-        names,
-        save_folder + "tankpositions.png",
-        consider_names,
-        consider_categories,
-    )
-    evaluate_trajectories(
-        paths,
-        names,
-        save_folder + "trajectories.png",
-        consider_names,
-        consider_categories,
-    )
-    evaluate_positionVec(
-        paths, names, save_folder + "posVec.png", consider_names, consider_categories
-    )
-    evaluate_follow_iid(
-        paths,
-        names,
-        save_folder + "follow_iid.png",
-        consider_names,
-        consider_categories,
+        paths, names, save_folder + "relativeOrientation.png", predicate
     )
+    evaluate_distanceToWall(paths, names, save_folder + "distanceToWall.png", predicate)
+    evaluate_tankpositions(paths, names, save_folder + "tankpositions.png", predicate)
+    evaluate_trajectories(paths, names, save_folder + "trajectories.png", predicate)
+    evaluate_positionVec(paths, names, save_folder + "posVec.png", predicate)
+    evaluate_follow_iid(paths, names, save_folder + "follow_iid.png", predicate)
 
 
 def calculate_follow(a, b):
-- 
GitLab