diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 6cd935823298984404231f5ff9345d0b50abc5b9..91f26042c104bdffd7835bda08d4b201d013c6e3 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -176,13 +176,14 @@ app = typer.Typer() @app.command() -def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: +def evaluate2(eval_type: str, paths: List[str], save_path: str = None, max_files: int = None) -> None: eval_types = [ ("wd", robofish.evaluate.evaluate2.evaluate_wd), ("wd_hist", robofish.evaluate.evaluate2.evaluate_wd_hist), ("iid", robofish.evaluate.evaluate2.evaluate_iid), ("iid_hist", robofish.evaluate.evaluate2.evaluate_iid_hist), + ("all_tracks", robofish.evaluate.evaluate2.evaluate_all_tracks), ] func = [f for n, f in eval_types if n == eval_type] @@ -190,7 +191,7 @@ def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: assert len(func) == 1, f"Ambiguous eval_type {eval_type}" func = func[0] - func(paths, save_path=save_path) + func(paths, save_path=save_path, max_files= max_files) if __name__ == "__main__": diff --git a/src/robofish/evaluate/evaluate2.py b/src/robofish/evaluate/evaluate2.py index 7b7bea62d94bc435981cd411afc1d6cedc1290d6..13346e6aa4a20d7e00ffdd5f28035cb2f8df5497 100644 --- a/src/robofish/evaluate/evaluate2.py +++ b/src/robofish/evaluate/evaluate2.py @@ -9,18 +9,25 @@ except ImportError: warnings.warn("fish_models not found. evaluate_2 won't work.") -def load_dsets(paths, dsets): +def load_dsets(paths, dsets, max_files): assert (paths is None) != (dsets is None), "Either paths or dsets must be given." if paths is not None: dsets = [] for p in paths: - prepared_p = Path(p) / "prepared_dset.pkl" - if prepared_p.exists(): - print("loading", prepared_p) - dset = fish_models.IoDataset.load(prepared_p) + if max_files is None: + prepared_p = Path(p) / "prepared_dset.pkl" else: - dset = fish_models.IoDataset(p) - dset.save(prepared_p) + prepared_p = Path(p) / f"prepared_dset_{max_files}.pkl" + + dset = None + if prepared_p.exists(): + dset = fish_models.IoDataset.load(prepared_p) + + if dset is None or (max_files is not None and len(dset.files) > max_files): + + assert Path(p).glob("*.hdf5"), f"No hdf5 files found in {p}" + dset = fish_models.IoDataset(p, max_files=max_files) + dset.store(prepared_p) dsets.append(dset) return dsets @@ -41,10 +48,10 @@ def simple_plot(dsets, paths, metric, ylabel): ax[0, di].set_ylabel(ylabel) return fig, ax -def simple_hist_plot(dsets, paths, metric): +def simple_hist_plot(dsets, paths, metric, n_bins=100): fig, ax = plt.subplots(1, 1, figsize=(15, 5)) metric_data = [metric(dset).flatten() for dset in dsets] - ax.hist(metric_data, bins=50, alpha=0.5, histtype="step", label=paths) + ax.hist(metric_data, bins=n_bins, alpha=0.5, histtype="step", label=paths) ax.set_ylabel("Frequency") ax.legend() return ax @@ -60,26 +67,44 @@ def calculate_iid(dset): diff = np.linalg.norm(dset["poses"][:, 0, : , :2] - dset["poses"][:, 1, :, :2], axis=-1) return diff -def evaluate_wd(paths=None, dsets=None, save_path=None): - dsets = load_dsets(paths, dsets) +def evaluate_wd(paths=None, dsets=None, save_path=None, max_files=None): + dsets = load_dsets(paths, dsets, max_files) simple_plot(dsets, paths, calculate_wd, "Wall distance") save_or_show(save_path) -def evaluate_wd_hist(paths=None, dsets=None, save_path=None): - dsets = load_dsets(paths, dsets) +def evaluate_wd_hist(paths=None, dsets=None, save_path=None, max_files=None): + dsets = load_dsets(paths, dsets, max_files) ax = simple_hist_plot(dsets, paths, calculate_wd) ax.set_title(f"Wall distance histogram") ax.set_xlabel("Wall distance") save_or_show(save_path) -def evaluate_iid(paths=None, dsets=None, save_path=None): - dsets = load_dsets(paths, dsets) +def evaluate_iid(paths=None, dsets=None, save_path=None, max_files=None): + dsets = load_dsets(paths, dsets, max_files) simple_plot(dsets, paths, calculate_iid, "IID") save_or_show(save_path) -def evaluate_iid_hist(paths=None, dsets=None, save_path=None): - dsets = load_dsets(paths, dsets) +def evaluate_iid_hist(paths=None, dsets=None, save_path=None, max_files=None): + dsets = load_dsets(paths, dsets, max_files) ax = simple_hist_plot(dsets, paths, calculate_iid) ax.set_title(f"IID histogram") ax.set_xlabel("IID") + save_or_show(save_path) + +def evaluate_all_tracks(paths=None, dsets=None, save_path=None, max_files=None): + dsets = load_dsets(paths, dsets, max_files) + fig, ax = plt.subplots(1, len(paths), figsize=(15, 5), squeeze=False) + + for di, dset in enumerate(dsets): + if dset["poses"].shape[0] <= 20: + cmap = plt.get_cmap("tab20") + else: + cmap = plt.get_cmap("Blues") + for fi, f in enumerate(dset["poses"][:max_files]): + for e in f: + ax[0,di].plot(e[:, 0], e[:, 1], c=cmap(fi / dset["poses"].shape[0]), alpha=0.5) + + ax[0,di].set_title(paths[di]) + ax[0,di].set_xlabel("x [cm]") + ax[0,di].set_ylabel("y [cm]") save_or_show(save_path) \ No newline at end of file