From bc5ea988c98fa8888230ecaea61455d7ccd075d9 Mon Sep 17 00:00:00 2001
From: Andi <andi.gerken@gmail.com>
Date: Thu, 27 Mar 2025 11:33:19 +0100
Subject: [PATCH] Added evaluation to show all tracks in one plot.

---
 src/robofish/evaluate/app.py      |  1 +
 src/robofish/evaluate/evaluate.py | 32 +++++++++++++++++++++++++++++++
 2 files changed, 33 insertions(+)

diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index 78697fd..ba38642 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -37,6 +37,7 @@ def function_dict() -> dict:
         "follow_iid": base.evaluate_follow_iid,
         "individual_speed": base.evaluate_individual_speed,
         "individual_iid": base.evaluate_individual_iid,
+        "single_plot_tracks": base.evaluate_single_plot_tracks,
         # "quiver": base.evaluate_quiver, # Quiver has issues with multiple paths and raises exceptions. The function is not used for now.
         "all": base.evaluate_all,
     }
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 793eb31..79c85f0 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -627,6 +627,38 @@ def evaluate_quiver(
     return fig
 
 
+def evaluate_single_plot_tracks(
+    paths: Iterable[Union[str, Path]],
+    labels: Iterable[str] = None,
+    predicate: Callable[[robofish.io.Entity], bool] = None,
+    poses_from_paths: Iterable[Iterable[np.ndarray]] = None,
+    max_files: int = None,
+):
+    """Evaluate the tracks of the entities as a plot."""
+    if poses_from_paths is None:
+        poses_from_paths, file_settings = utils.get_all_poses_from_paths(
+            paths, predicate, max_files=max_files
+        )
+
+    fig, ax = plt.subplots(
+        1, len(poses_from_paths), figsize=(8 * len(poses_from_paths), 8)
+    )
+    if len(poses_from_paths) == 1:
+        ax = [ax]
+
+    cmap = plt.get_cmap("Blues")
+    for i, poses_per_path in enumerate(poses_from_paths):
+        for fi, poses in enumerate(poses_per_path):
+            for e in poses:
+                ax[i].plot(
+                    e[:, 0], e[:, 1], c=cmap(fi / len(poses_per_path)), alpha=0.5
+                )
+        ax[i].set_title(labels[i])
+        ax[i].set_xlabel("x [cm]")
+        ax[i].set_ylabel("y [cm]")
+    return fig
+
+
 def evaluate_social_vector(
     paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
-- 
GitLab