diff --git a/setup.py b/setup.py index 5af4ce7e8b8de7f1b7c148735c717543b8a9d22f..f7f775368784bb215fefe3fd7604d5d72d446cf5 100644 --- a/setup.py +++ b/setup.py @@ -14,6 +14,7 @@ entry_points = { "robofish-io-update-calculated-data=robofish.io.app:update_calculated_data", # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz "robofish-io-evaluate=robofish.evaluate.app:evaluate", + "robofish-io-evaluate2=robofish.evaluate.app:app", "robofish-io-update-individual-ids = robofish.io.app:update_individual_ids", "robofish-io-update-world-shape = robofish.io.app:update_world_shape", "robofish-io-overwrite_user_configs = robofish.io.app:overwrite_user_configs", diff --git a/src/robofish/evaluate/__init__.py b/src/robofish/evaluate/__init__.py index d471a761c25cd04c103f72939d45cc4112649834..2390684cb505e8c72645b931dacd888e01f182c5 100644 --- a/src/robofish/evaluate/__init__.py +++ b/src/robofish/evaluate/__init__.py @@ -4,6 +4,7 @@ import logging import robofish.io from robofish.evaluate.evaluate import * +from robofish.evaluate.evaluate2 import evaluate_wd import robofish.evaluate.app if not ((3, 7) <= sys.version_info < (4, 0)): diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index ba38642019b4f16f121d608a6cb2748459f5144e..c665542e31c2253e333cd48d1ddb275dee45fa11 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -10,7 +10,9 @@ Functions available to be used in the commandline to evaluate robofish.io files. # Last doku update Feb 2021 import robofish.evaluate +import typer +from typing import List, Union import argparse from pathlib import Path import matplotlib.pyplot as plt @@ -168,3 +170,23 @@ def evaluate(args: dict = None) -> None: plt.close(fig) else: print(f"Evaluation function not found {args.analysis_type}") + + +app = typer.Typer() + + +@app.command() +def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: + + eval_types = [("wd", robofish.evaluate.evaluate2.evaluate_wd)] + + func = [f for n, f in eval_types if n == eval_type] + assert len(func) != 0, f"Unknown eval_type {eval_type}" + assert len(func) == 1, f"Ambiguous eval_type {eval_type}" + func = func[0] + + func(paths, save_path=save_path) + + +if __name__ == "__main__": + app() diff --git a/src/robofish/evaluate/evaluate2.py b/src/robofish/evaluate/evaluate2.py new file mode 100644 index 0000000000000000000000000000000000000000..aea2a2185e402fd3fa9d235a6ad3dc36cebe4a63 --- /dev/null +++ b/src/robofish/evaluate/evaluate2.py @@ -0,0 +1,37 @@ +import warnings +import numpy as np +import matplotlib.pyplot as plt +from pathlib import Path + +try: + import fish_models +except ImportError: + warnings.warn("fish_models not found. evaluate_2 won't work.") + + +def evaluate_wd(paths=None, dsets=None, save_path=None): + assert (paths is None) != (dsets is None), "Either paths or dsets must be given." + + if paths is not None: + dsets = [fish_models.IoDataset(p) for p in paths] + + fig, ax = plt.subplots(1, (len(dsets)), figsize=(15, 5), squeeze=False) + for di, dset in enumerate(dsets): + world_size = dset.world_size + assert world_size[0] == world_size[1], "World is not square" + world_size = world_size[0] + + wd = world_size / 2 - np.max(dset["poses"][..., :2], axis=(-1)) + + for f_wd in wd: + + ax[0, di].plot(f_wd.T) + + ax[0, di].set_title(f"{Path(paths[di]).name}") + ax[0, di].set_xlabel("Time") + ax[0, di].set_ylabel("Wall distance") + + if save_path is None: + plt.show() + else: + plt.savefig("wd.png")