From bc5ea988c98fa8888230ecaea61455d7ccd075d9 Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Thu, 27 Mar 2025 11:33:19 +0100 Subject: [PATCH] Added evaluation to show all tracks in one plot. --- src/robofish/evaluate/app.py | 1 + src/robofish/evaluate/evaluate.py | 32 +++++++++++++++++++++++++++++++ 2 files changed, 33 insertions(+) diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 78697fd..ba38642 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -37,6 +37,7 @@ def function_dict() -> dict: "follow_iid": base.evaluate_follow_iid, "individual_speed": base.evaluate_individual_speed, "individual_iid": base.evaluate_individual_iid, + "single_plot_tracks": base.evaluate_single_plot_tracks, # "quiver": base.evaluate_quiver, # Quiver has issues with multiple paths and raises exceptions. The function is not used for now. "all": base.evaluate_all, } diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 793eb31..79c85f0 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -627,6 +627,38 @@ def evaluate_quiver( return fig +def evaluate_single_plot_tracks( + paths: Iterable[Union[str, Path]], + labels: Iterable[str] = None, + predicate: Callable[[robofish.io.Entity], bool] = None, + poses_from_paths: Iterable[Iterable[np.ndarray]] = None, + max_files: int = None, +): + """Evaluate the tracks of the entities as a plot.""" + if poses_from_paths is None: + poses_from_paths, file_settings = utils.get_all_poses_from_paths( + paths, predicate, max_files=max_files + ) + + fig, ax = plt.subplots( + 1, len(poses_from_paths), figsize=(8 * len(poses_from_paths), 8) + ) + if len(poses_from_paths) == 1: + ax = [ax] + + cmap = plt.get_cmap("Blues") + for i, poses_per_path in enumerate(poses_from_paths): + for fi, poses in enumerate(poses_per_path): + for e in poses: + ax[i].plot( + e[:, 0], e[:, 1], c=cmap(fi / len(poses_per_path)), alpha=0.5 + ) + ax[i].set_title(labels[i]) + ax[i].set_xlabel("x [cm]") + ax[i].set_ylabel("y [cm]") + return fig + + def evaluate_social_vector( paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, -- GitLab