Skip to content
Snippets Groups Projects
Commit 1a0d2870 authored by Andi Gerken's avatar Andi Gerken
Browse files

Added evaluate2 which will use IoDatasets instead of loading everything by hand.

parent bc5ea988
No related branches found
No related tags found
1 merge request!54robofish-io-evaluate2 and lazy imports
......@@ -14,6 +14,7 @@ entry_points = {
"robofish-io-update-calculated-data=robofish.io.app:update_calculated_data",
# TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz
"robofish-io-evaluate=robofish.evaluate.app:evaluate",
"robofish-io-evaluate2=robofish.evaluate.app:app",
"robofish-io-update-individual-ids = robofish.io.app:update_individual_ids",
"robofish-io-update-world-shape = robofish.io.app:update_world_shape",
"robofish-io-overwrite_user_configs = robofish.io.app:overwrite_user_configs",
......
......@@ -4,6 +4,7 @@ import logging
import robofish.io
from robofish.evaluate.evaluate import *
from robofish.evaluate.evaluate2 import evaluate_wd
import robofish.evaluate.app
if not ((3, 7) <= sys.version_info < (4, 0)):
......
......@@ -10,7 +10,9 @@ Functions available to be used in the commandline to evaluate robofish.io files.
# Last doku update Feb 2021
import robofish.evaluate
import typer
from typing import List, Union
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
......@@ -168,3 +170,23 @@ def evaluate(args: dict = None) -> None:
plt.close(fig)
else:
print(f"Evaluation function not found {args.analysis_type}")
app = typer.Typer()
@app.command()
def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None:
eval_types = [("wd", robofish.evaluate.evaluate2.evaluate_wd)]
func = [f for n, f in eval_types if n == eval_type]
assert len(func) != 0, f"Unknown eval_type {eval_type}"
assert len(func) == 1, f"Ambiguous eval_type {eval_type}"
func = func[0]
func(paths, save_path=save_path)
if __name__ == "__main__":
app()
import warnings
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
try:
import fish_models
except ImportError:
warnings.warn("fish_models not found. evaluate_2 won't work.")
def evaluate_wd(paths=None, dsets=None, save_path=None):
assert (paths is None) != (dsets is None), "Either paths or dsets must be given."
if paths is not None:
dsets = [fish_models.IoDataset(p) for p in paths]
fig, ax = plt.subplots(1, (len(dsets)), figsize=(15, 5), squeeze=False)
for di, dset in enumerate(dsets):
world_size = dset.world_size
assert world_size[0] == world_size[1], "World is not square"
world_size = world_size[0]
wd = world_size / 2 - np.max(dset["poses"][..., :2], axis=(-1))
for f_wd in wd:
ax[0, di].plot(f_wd.T)
ax[0, di].set_title(f"{Path(paths[di]).name}")
ax[0, di].set_xlabel("Time")
ax[0, di].set_ylabel("Wall distance")
if save_path is None:
plt.show()
else:
plt.savefig("wd.png")
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment