From 1a0d28708960d28f0e4a4034125f640ee1a002e5 Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Thu, 27 Mar 2025 12:44:01 +0100 Subject: [PATCH] Added evaluate2 which will use IoDatasets instead of loading everything by hand. --- setup.py | 1 + src/robofish/evaluate/__init__.py | 1 + src/robofish/evaluate/app.py | 22 ++++++++++++++++++ src/robofish/evaluate/evaluate2.py | 37 ++++++++++++++++++++++++++++++ 4 files changed, 61 insertions(+) create mode 100644 src/robofish/evaluate/evaluate2.py diff --git a/setup.py b/setup.py index 5af4ce7..f7f7753 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ entry_points = { "robofish-io-update-calculated-data=robofish.io.app:update_calculated_data", # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz "robofish-io-evaluate=robofish.evaluate.app:evaluate", + "robofish-io-evaluate2=robofish.evaluate.app:app", "robofish-io-update-individual-ids = robofish.io.app:update_individual_ids", "robofish-io-update-world-shape = robofish.io.app:update_world_shape", "robofish-io-overwrite_user_configs = robofish.io.app:overwrite_user_configs", diff --git a/src/robofish/evaluate/__init__.py b/src/robofish/evaluate/__init__.py index d471a76..2390684 100644 --- a/src/robofish/evaluate/__init__.py +++ b/src/robofish/evaluate/__init__.py @@ -4,6 +4,7 @@ import logging import robofish.io from robofish.evaluate.evaluate import * +from robofish.evaluate.evaluate2 import evaluate_wd import robofish.evaluate.app if not ((3, 7) <= sys.version_info < (4, 0)): diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index ba38642..c665542 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -10,7 +10,9 @@ Functions available to be used in the commandline to evaluate robofish.io files. # Last doku update Feb 2021 import robofish.evaluate +import typer +from typing import List, Union import argparse from pathlib import Path import matplotlib.pyplot as plt @@ -168,3 +170,23 @@ def evaluate(args: dict = None) -> None: plt.close(fig) else: print(f"Evaluation function not found {args.analysis_type}") + + +app = typer.Typer() + + +@app.command() +def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: + + eval_types = [("wd", robofish.evaluate.evaluate2.evaluate_wd)] + + func = [f for n, f in eval_types if n == eval_type] + assert len(func) != 0, f"Unknown eval_type {eval_type}" + assert len(func) == 1, f"Ambiguous eval_type {eval_type}" + func = func[0] + + func(paths, save_path=save_path) + + +if __name__ == "__main__": + app() diff --git a/src/robofish/evaluate/evaluate2.py b/src/robofish/evaluate/evaluate2.py new file mode 100644 index 0000000..aea2a21 --- /dev/null +++ b/src/robofish/evaluate/evaluate2.py @@ -0,0 +1,37 @@ +import warnings +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + +try: + import fish_models +except ImportError: + warnings.warn("fish_models not found. evaluate_2 won't work.") + + +def evaluate_wd(paths=None, dsets=None, save_path=None): + assert (paths is None) != (dsets is None), "Either paths or dsets must be given." + + if paths is not None: + dsets = [fish_models.IoDataset(p) for p in paths] + + fig, ax = plt.subplots(1, (len(dsets)), figsize=(15, 5), squeeze=False) + for di, dset in enumerate(dsets): + world_size = dset.world_size + assert world_size[0] == world_size[1], "World is not square" + world_size = world_size[0] + + wd = world_size / 2 - np.max(dset["poses"][..., :2], axis=(-1)) + + for f_wd in wd: + + ax[0, di].plot(f_wd.T) + + ax[0, di].set_title(f"{Path(paths[di]).name}") + ax[0, di].set_xlabel("Time") + ax[0, di].set_ylabel("Wall distance") + + if save_path is None: + plt.show() + else: + plt.savefig("wd.png") -- GitLab