Skip to content
Snippets Groups Projects

Added calculation of individual ids

Merged Andi Gerken requested to merge develop into master
Files
8
@@ -18,7 +18,7 @@ import matplotlib.gridspec as gridspec
import seaborn as sns
import numpy as np
import pandas as pd
from typing import Iterable, Union, Callable
from typing import Iterable, Union, Callable, List
from scipy import stats
from tqdm import tqdm
import inspect
@@ -449,7 +449,7 @@ def evaluate_tank_position(
x=reduced_xy_positions[:, 0],
y=reduced_xy_positions[:, 1],
n_levels=20,
shade=True,
fill=True,
ax=ax[i],
)
@@ -494,7 +494,7 @@ def evaluate_quiver(
if poses_from_paths is None:
poses_from_paths, file_settings = utils.get_all_poses_from_paths(
paths, predicate
paths, predicate, predicate
)
if speeds_turns_from_paths is None:
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
@@ -518,7 +518,7 @@ def evaluate_quiver(
speeds_turns_from_paths = np.array(speeds_turns_from_paths)
except Exception as e:
warnings.warn(
f"The conversion to numpy array failed:\nspeeds_turns_from_path was {speeds_turns_from_paths}. Exception was {e}"
f"The conversion to numpy array failed:\nspeeds_turns_from_path.shape was {speeds_turns_from_paths.shape}. Exception was {e}"
)
return
@@ -877,7 +877,7 @@ def evaluate_tracks(
files_per_path = utils.get_all_files_from_paths(paths)
max_files_per_path = max([len(files) for files in files_per_path])
rows, cols = len(files_per_path), min(4, max_files_per_path)
rows, cols = len(files_per_path), min(6, max_files_per_path)
multirow = False
if rows == 1 and cols == 4:
@@ -908,31 +908,31 @@ def evaluate_tracks(
initial_poses_info_available is not None
and initial_poses_info_available < k
):
if len(all_initial_info) > i:
initial_poses_info = all_initial_info[i]
initial_poses_info = all_initial_info[i]
preload_data = Path(initial_poses_info["preload_data"])
preload_data = Path(initial_poses_info["preload_data"])
assert (
preload_data.parts == paths[k].parts[-len(preload_data.parts) :]
), f"The second given path should correspond to the preload_data of the second path.\n{preload_data.parts}\n{paths[k][-len(preload_data.parts) :]}"
assert (
preload_data.parts == paths[k].parts[-len(preload_data.parts) :]
), f"The second given path should correspond to the preload_data of the second path.\nExpected preload: {preload_data.parts}\nGiven Path: {paths[k].parts[-len(preload_data.parts) :]}"
# Open the correct file instead.
# Open the correct file instead.
new_file_path = paths[k] / initial_poses_info["file_name"]
if verbose:
print(f"Using file {file_path}")
with robofish.io.File(new_file_path, "r") as new_file:
new_file_path = paths[k] / initial_poses_info["file_name"]
if verbose:
print(f"Using file {file_path}")
with robofish.io.File(new_file_path, "r") as new_file:
new_file.plot(
selected_ax,
lw_distances=lw_distances,
skip_timesteps=initial_poses_info["start"],
max_timesteps=initial_poses_info["n_timesteps"],
)
selected_ax.set_title(
f"{selected_ax.get_title()} (reference track: id {initial_poses_info['track_id']})",
)
new_file.plot(
selected_ax,
lw_distances=lw_distances,
skip_timesteps=initial_poses_info["start"],
max_timesteps=initial_poses_info["n_timesteps"],
)
selected_ax.set_title(
f"{selected_ax.get_title()} (reference track: id {initial_poses_info['track_id']})",
)
else:
with robofish.io.File(file_path, "r") as file:
@@ -1147,6 +1147,7 @@ def evaluate_all(
fdict: dict = None,
predicate: Callable[[robofish.io.Entity], bool] = None,
max_files: int = None,
evaluations: List[str] = None,
) -> Iterable[Path]:
"""Generate all evaluation graphs and save them to a folder.
@@ -1157,6 +1158,7 @@ def evaluate_all(
fdict(dict): The dictionary with the functions which should be called. If none is given, all functions are used.
predicate(Callable[[robofish.io.Entity], bool]): A lambda function, selecting entities.
max_files(int): The maximum number of files to plot.
evaluations(List[str]): A list of all evaluations to plot. Values should correspond to the keys in fdict.
Returns:
Iterable[Path]: An array of all paths to the created images.
"""
@@ -1164,7 +1166,7 @@ def evaluate_all(
save_folder is not None and save_folder != ""
), "Please provide a save_folder using --save_path"
save_folder.mkdir(exist_ok=True)
save_folder.mkdir(exist_ok=True, parents=True)
save_paths = []
if fdict is None:
@@ -1188,22 +1190,23 @@ def evaluate_all(
t = tqdm(fdict.items(), desc="Evaluation", leave=True)
for f_name, f_callable in t:
t.set_description(f_name)
t.refresh() # to show the update immediately
save_path = save_folder / (f_name + ".png")
requested_inputs = {
k: input_dict[k]
for k in inspect.signature(f_callable).parameters.keys()
if k in input_dict
}
fig = f_callable(
paths=paths, labels=labels, predicate=predicate, **requested_inputs
)
if fig is not None:
fig.savefig(save_path)
plt.close(fig)
save_paths.append(save_path)
if evaluations is None or f_name in evaluations:
t.set_description(f_name)
t.refresh() # to show the update immediately
save_path = save_folder / (f_name + ".png")
requested_inputs = {
k: input_dict[k]
for k in inspect.signature(f_callable).parameters.keys()
if k in input_dict
}
fig = f_callable(
paths=paths, labels=labels, predicate=predicate, **requested_inputs
)
if fig is not None:
fig.savefig(save_path)
plt.close(fig)
save_paths.append(save_path)
return save_paths
@@ -9,14 +9,53 @@ The files are saved in the `.hdf5` format and following the [Track Format Specif
import sys
import logging
import yaml
from robofish.io.file import *
from robofish.io.entity import *
from robofish.io.validation import *
from robofish.io.io import *
from robofish.io.utils import *
import robofish.io.app
if not ((3, 7) <= sys.version_info < (4, 0)):
logging.warning("Unsupported Python version")
def user_config(overwrite=True) -> dict:
"""Returns the user config.
Returns:
dict: Dict with all user settings.
"""
p = Path.home() / ".robofish/io" / "config.yaml"
print("Reading user config from", p)
if not p.exists() or overwrite:
logging.warning(
"User config not found (or overwrite). Creating file and searching for fish_models."
)
p.parent.mkdir(parents=True, exist_ok=True)
p.touch()
import imp
fm = imp.find_module("fish_models")[1]
p.write_text(yaml.dump({"fish_models_path": str(Path(fm).parents[1])}))
# Load config
with open(p, "r") as f:
config = yaml.safe_load(f)
if config is None:
config = {}
assert (
"fish_models_path" in config
), "fish_models_path: Please add the path to the fish_models package to the config file. Update the file automatically with the bash command `roborish-io-overwrite_user_configs`."
assert Path(
config["fish_models_path"]
).exists(), f"fish_models_path: The path to the fish_models package does not exist. {config['fish_models_path']}. Update the file automatically with the bash command `roborish-io-overwrite_user_configs`."
return config
+ 89
39
@@ -17,15 +17,22 @@ import argparse
import logging
import warnings
from tqdm.auto import tqdm
from typing import Dict
import itertools
import numpy as np
import multiprocessing
import copy
from pathlib import Path
def print_file(args=None):
def print_file(args: argparse.Namespace = None) -> bool:
"""This function can be used to print hdf5 files from the command line
Args:
args (argparse.Namespace, optional): The arguments for the render function. This is mainly used for testing.
Returns:
A human readable print of a given hdf5 file.
bool: A boolean if the file was invalid [True if invalid, False if not]
"""
parser = argparse.ArgumentParser(
description="This function can be used to print hdf5 files from the command line"
@@ -57,7 +64,12 @@ def print_file(args=None):
return not valid
def update_calculated_data(args=None):
def update_calculated_data(args: argparse.Namespace = None) -> None:
"""This function updates the calculated data of any number of files or folders.
Args:
args (argparse.Namespace, optional): The arguments for the render function. This is mainly used for testing.
"""
parser = argparse.ArgumentParser(
description="This function updates all calculated data from files."
)
@@ -92,45 +104,18 @@ def update_calculated_data(args=None):
print(e)
# This should not be neccessary since the data will always have the calculated data by default.
# def clear_calculated_data(args=None):
# parser = argparse.ArgumentParser(
# description="This function clears calculated data from robofish.io files."
# )
# parser.add_argument(
# "path",
# type=str,
# nargs="+",
# help="The path to one or multiple files and/or folders.",
# )
# if args is None:
# args = parser.parse_args()
# files_per_path = utils.get_all_files_from_paths(args.path)
# files = [
# f for f_in_path in files_per_path for f in f_in_path
# ] # Concatenate all files to one list
# assert len(files) > 0, f"No files found in path {args.path}."
# for fp in files:
# print(f"File {fp}")
# with robofish.io.File(
# fp, "r+", validate_poses_hash=False, store_calculated_data=False
# ) as f:
# f.clear_calculated_data()
def validate(args=None):
def validate(args: argparse.Namespace = None) -> int:
"""This function can be used to validate hdf5 files.
The function can be directly accessed from the commandline and can be given
any number of files or folders. The function returns the validity of the files
in a human readable format or as a raw output.
Args:
args (argparse.Namespace, optional): The arguments for the render function. This is mainly used for testing.
Returns:
A human readable table of each file and its validity
int: An error code. 0 if all files are valid, 1 if at least one file is invalid.
"""
parser = argparse.ArgumentParser(
description="The function can be directly accessed from the commandline and can be given any number of files or folders. The function returns the validity of the files in a human readable format or as a raw output."
@@ -187,13 +172,37 @@ def validate(args=None):
return error_code
def render(args=None):
def render_file(kwargs: Dict) -> None:
"""This function renders a single file.
Args:
kwargs (Dict, optional): A dictionary containing the arguments for the render function.
"""
with robofish.io.File(path=kwargs["path"]) as f:
f.render(**kwargs)
def overwrite_user_configs() -> None:
"""This function overwrites the user configs with the default config."""
robofish.io.user_config(overwrite=True)
def render(args: argparse.Namespace = None) -> None:
"""This function can be used to render hdf5 files.
The function can be directly accessed from the commandline and can be given
any number of files.
Args:
args (argparse.Namespace, optional): The arguments for the render function. This is mainly used for testing.
"""
parser = argparse.ArgumentParser(
description="This function shows the file as animation."
)
parser.add_argument(
"path",
type=str,
nargs="+",
help="The path to one file.",
)
@@ -205,6 +214,13 @@ def render(args=None):
help="Path to save the video to (mp4). If a path is given, the animation won't be played.",
)
parser.add_argument(
"--reference_track",
action="store_true",
help="If true, the reference track will be rendered in parallel.",
default=False,
)
default_options = {
"linewidth": 2,
"speedup": 1,
@@ -216,6 +232,7 @@ def render(args=None):
"cut_frames_start": 0,
"cut_frames_end": 0,
"show_text": False,
"show_ids": False,
"render_goals": False,
"render_targets": False,
"figsize": 10,
@@ -240,10 +257,43 @@ def render(args=None):
if args is None:
args = parser.parse_args()
print(args)
if args.reference_track:
assert (
len(args.path) == 1
), "Only one file can be rendered with the reference track."
f = robofish.io.File(path=args.path)
f.render(**vars(args))
# Load fish_models path from user_config.yaml
user_config = robofish.io.user_config()
with robofish.io.File(args.path[0]) as f:
data_folder = f["initial_poses_info"].attrs["preload_data"]
reference_track = f["initial_poses_info"].attrs["file_name"]
reference_track_path = (
Path(user_config["fish_models_path"])
/ "storage/raw_data"
/ data_folder
/ reference_track
)
args.path.append(str(reference_track_path))
if len(args.path) > 1:
print("Found multiple paths, starting multiprocessing.")
args_array = []
for v in args.path:
kwargs = copy.copy(vars(args))
kwargs["path"] = v
args_array.append(kwargs)
# Render multiple animations with multiprocessing
with multiprocessing.Pool() as pool:
pool.map(render_file, args_array)
else:
kwargs = vars(args)
kwargs["path"] = args.path[0]
render_file(kwargs)
def update_individual_ids(args=None):
Loading