Skip to content
Snippets Groups Projects
Commit 58ed1c08 authored by Andi Gerken's avatar Andi Gerken Committed by Andi Gerken
Browse files

added all_tracks evaluation to evaluate2

parent 3a136628
Branches
Tags
1 merge request!54robofish-io-evaluate2 and lazy imports
...@@ -176,13 +176,14 @@ app = typer.Typer() ...@@ -176,13 +176,14 @@ app = typer.Typer()
@app.command() @app.command()
def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: def evaluate2(eval_type: str, paths: List[str], save_path: str = None, max_files: int = None) -> None:
eval_types = [ eval_types = [
("wd", robofish.evaluate.evaluate2.evaluate_wd), ("wd", robofish.evaluate.evaluate2.evaluate_wd),
("wd_hist", robofish.evaluate.evaluate2.evaluate_wd_hist), ("wd_hist", robofish.evaluate.evaluate2.evaluate_wd_hist),
("iid", robofish.evaluate.evaluate2.evaluate_iid), ("iid", robofish.evaluate.evaluate2.evaluate_iid),
("iid_hist", robofish.evaluate.evaluate2.evaluate_iid_hist), ("iid_hist", robofish.evaluate.evaluate2.evaluate_iid_hist),
("all_tracks", robofish.evaluate.evaluate2.evaluate_all_tracks),
] ]
func = [f for n, f in eval_types if n == eval_type] func = [f for n, f in eval_types if n == eval_type]
...@@ -190,7 +191,7 @@ def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None: ...@@ -190,7 +191,7 @@ def evaluate2(eval_type: str, paths: List[str], save_path: str = None) -> None:
assert len(func) == 1, f"Ambiguous eval_type {eval_type}" assert len(func) == 1, f"Ambiguous eval_type {eval_type}"
func = func[0] func = func[0]
func(paths, save_path=save_path) func(paths, save_path=save_path, max_files= max_files)
if __name__ == "__main__": if __name__ == "__main__":
......
...@@ -9,18 +9,25 @@ except ImportError: ...@@ -9,18 +9,25 @@ except ImportError:
warnings.warn("fish_models not found. evaluate_2 won't work.") warnings.warn("fish_models not found. evaluate_2 won't work.")
def load_dsets(paths, dsets): def load_dsets(paths, dsets, max_files):
assert (paths is None) != (dsets is None), "Either paths or dsets must be given." assert (paths is None) != (dsets is None), "Either paths or dsets must be given."
if paths is not None: if paths is not None:
dsets = [] dsets = []
for p in paths: for p in paths:
prepared_p = Path(p) / "prepared_dset.pkl" if max_files is None:
if prepared_p.exists(): prepared_p = Path(p) / "prepared_dset.pkl"
print("loading", prepared_p)
dset = fish_models.IoDataset.load(prepared_p)
else: else:
dset = fish_models.IoDataset(p) prepared_p = Path(p) / f"prepared_dset_{max_files}.pkl"
dset.save(prepared_p)
dset = None
if prepared_p.exists():
dset = fish_models.IoDataset.load(prepared_p)
if dset is None or (max_files is not None and len(dset.files) > max_files):
assert Path(p).glob("*.hdf5"), f"No hdf5 files found in {p}"
dset = fish_models.IoDataset(p, max_files=max_files)
dset.store(prepared_p)
dsets.append(dset) dsets.append(dset)
return dsets return dsets
...@@ -41,10 +48,10 @@ def simple_plot(dsets, paths, metric, ylabel): ...@@ -41,10 +48,10 @@ def simple_plot(dsets, paths, metric, ylabel):
ax[0, di].set_ylabel(ylabel) ax[0, di].set_ylabel(ylabel)
return fig, ax return fig, ax
def simple_hist_plot(dsets, paths, metric): def simple_hist_plot(dsets, paths, metric, n_bins=100):
fig, ax = plt.subplots(1, 1, figsize=(15, 5)) fig, ax = plt.subplots(1, 1, figsize=(15, 5))
metric_data = [metric(dset).flatten() for dset in dsets] metric_data = [metric(dset).flatten() for dset in dsets]
ax.hist(metric_data, bins=50, alpha=0.5, histtype="step", label=paths) ax.hist(metric_data, bins=n_bins, alpha=0.5, histtype="step", label=paths)
ax.set_ylabel("Frequency") ax.set_ylabel("Frequency")
ax.legend() ax.legend()
return ax return ax
...@@ -60,26 +67,44 @@ def calculate_iid(dset): ...@@ -60,26 +67,44 @@ def calculate_iid(dset):
diff = np.linalg.norm(dset["poses"][:, 0, : , :2] - dset["poses"][:, 1, :, :2], axis=-1) diff = np.linalg.norm(dset["poses"][:, 0, : , :2] - dset["poses"][:, 1, :, :2], axis=-1)
return diff return diff
def evaluate_wd(paths=None, dsets=None, save_path=None): def evaluate_wd(paths=None, dsets=None, save_path=None, max_files=None):
dsets = load_dsets(paths, dsets) dsets = load_dsets(paths, dsets, max_files)
simple_plot(dsets, paths, calculate_wd, "Wall distance") simple_plot(dsets, paths, calculate_wd, "Wall distance")
save_or_show(save_path) save_or_show(save_path)
def evaluate_wd_hist(paths=None, dsets=None, save_path=None): def evaluate_wd_hist(paths=None, dsets=None, save_path=None, max_files=None):
dsets = load_dsets(paths, dsets) dsets = load_dsets(paths, dsets, max_files)
ax = simple_hist_plot(dsets, paths, calculate_wd) ax = simple_hist_plot(dsets, paths, calculate_wd)
ax.set_title(f"Wall distance histogram") ax.set_title(f"Wall distance histogram")
ax.set_xlabel("Wall distance") ax.set_xlabel("Wall distance")
save_or_show(save_path) save_or_show(save_path)
def evaluate_iid(paths=None, dsets=None, save_path=None): def evaluate_iid(paths=None, dsets=None, save_path=None, max_files=None):
dsets = load_dsets(paths, dsets) dsets = load_dsets(paths, dsets, max_files)
simple_plot(dsets, paths, calculate_iid, "IID") simple_plot(dsets, paths, calculate_iid, "IID")
save_or_show(save_path) save_or_show(save_path)
def evaluate_iid_hist(paths=None, dsets=None, save_path=None): def evaluate_iid_hist(paths=None, dsets=None, save_path=None, max_files=None):
dsets = load_dsets(paths, dsets) dsets = load_dsets(paths, dsets, max_files)
ax = simple_hist_plot(dsets, paths, calculate_iid) ax = simple_hist_plot(dsets, paths, calculate_iid)
ax.set_title(f"IID histogram") ax.set_title(f"IID histogram")
ax.set_xlabel("IID") ax.set_xlabel("IID")
save_or_show(save_path)
def evaluate_all_tracks(paths=None, dsets=None, save_path=None, max_files=None):
dsets = load_dsets(paths, dsets, max_files)
fig, ax = plt.subplots(1, len(paths), figsize=(15, 5), squeeze=False)
for di, dset in enumerate(dsets):
if dset["poses"].shape[0] <= 20:
cmap = plt.get_cmap("tab20")
else:
cmap = plt.get_cmap("Blues")
for fi, f in enumerate(dset["poses"][:max_files]):
for e in f:
ax[0,di].plot(e[:, 0], e[:, 1], c=cmap(fi / dset["poses"].shape[0]), alpha=0.5)
ax[0,di].set_title(paths[di])
ax[0,di].set_xlabel("x [cm]")
ax[0,di].set_ylabel("y [cm]")
save_or_show(save_path) save_or_show(save_path)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment