Skip to content
Snippets Groups Projects
Commit dbc34191 authored by Andi Gerken's avatar Andi Gerken
Browse files

Added max files argument in evaluation.

parent aaa78007
No related branches found
No related tags found
No related merge requests found
Pipeline #49293 passed
...@@ -533,8 +533,6 @@ def evaluate_follow_iid( ...@@ -533,8 +533,6 @@ def evaluate_follow_iid(
grids = [] grids = []
# Exclude possible nans
# Find the 0.5%/ 99.5% quantiles as min and max for the axis # Find the 0.5%/ 99.5% quantiles as min and max for the axis
follow_range = np.max( follow_range = np.max(
[ [
...@@ -678,23 +676,27 @@ def evaluate_tracks( ...@@ -678,23 +676,27 @@ def evaluate_tracks(
def evaluate_individual_speed( def evaluate_individual_speed(
speeds_turns_from_paths=None, file_settings=None, **kwargs speeds_turns_from_paths=None, file_settings=None, max_files=None, **kwargs
): ):
"""Evaluate the average speeds per individual""" """Evaluate the average speeds per individual"""
return evaluate_individuals( return evaluate_individuals(
speeds_turns_from_paths=speeds_turns_from_paths, speeds_turns_from_paths=speeds_turns_from_paths,
file_settings=file_settings, file_settings=file_settings,
mode="speed", mode="speed",
max_files=max_files,
**kwargs, **kwargs,
) )
def evaluate_individual_iid(poses_from_paths=None, file_settings=None, **kwargs): def evaluate_individual_iid(
poses_from_paths=None, file_settings=None, max_files=None, **kwargs
):
"""Evaluate the average iid per file""" """Evaluate the average iid per file"""
return evaluate_individuals( return evaluate_individuals(
poses_from_paths=poses_from_paths, poses_from_paths=poses_from_paths,
file_settings=file_settings, file_settings=file_settings,
mode="iid", mode="iid",
max_files=max_files,
**kwargs, **kwargs,
) )
...@@ -709,6 +711,7 @@ def evaluate_individuals( ...@@ -709,6 +711,7 @@ def evaluate_individuals(
predicate=None, predicate=None,
lw_distances=False, lw_distances=False,
threshold=7, threshold=7,
max_files=None,
): ):
"""Evaluate the average speeds per individual. """Evaluate the average speeds per individual.
...@@ -723,14 +726,14 @@ def evaluate_individuals( ...@@ -723,14 +726,14 @@ def evaluate_individuals(
if speeds_turns_from_paths is None and mode == "speed": if speeds_turns_from_paths is None and mode == "speed":
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns" paths, "speeds_turns", max_files=max_files
) )
if poses_from_paths is None and mode == "iid": if poses_from_paths is None and mode == "iid":
poses_from_paths, file_settings = utils.get_all_data_from_paths( poses_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "poses_4d" paths, "poses_4d", max_files=max_files
) )
files_from_paths = utils.get_all_files_from_paths(paths) files_from_paths = utils.get_all_files_from_paths(paths, max_files)
fig = plt.figure(figsize=(10, 4)) fig = plt.figure(figsize=(10, 4))
# small_iid_files = [] # small_iid_files = []
...@@ -813,6 +816,7 @@ def evaluate_all( ...@@ -813,6 +816,7 @@ def evaluate_all(
save_folder: Path = None, save_folder: Path = None,
fdict: dict = None, fdict: dict = None,
predicate=None, predicate=None,
max_files=None,
): ):
"""Generate all evaluation graphs and save them to a folder. """Generate all evaluation graphs and save them to a folder.
...@@ -838,15 +842,18 @@ def evaluate_all( ...@@ -838,15 +842,18 @@ def evaluate_all(
fdict = robofish.evaluate.app.function_dict() fdict = robofish.evaluate.app.function_dict()
fdict.pop("all") fdict.pop("all")
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) poses_from_paths, file_settings = utils.get_all_poses_from_paths(
paths, max_files=max_files
)
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns" paths, "speeds_turns", max_files=max_files
) )
input_dict = { input_dict = {
"poses_from_paths": poses_from_paths, "poses_from_paths": poses_from_paths,
"speeds_turns_from_paths": speeds_turns_from_paths, "speeds_turns_from_paths": speeds_turns_from_paths,
"file_settings": file_settings, "file_settings": file_settings,
"max_files": max_files,
} }
t = tqdm(fdict.items(), desc="Evaluation", leave=True) t = tqdm(fdict.items(), desc="Evaluation", leave=True)
......
...@@ -40,7 +40,7 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)): ...@@ -40,7 +40,7 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)):
return angle return angle
def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None):
# Find all files with correct ending # Find all files with correct ending
files = [] files = []
for path in [Path(p) for p in paths]: for path in [Path(p) for p in paths]:
...@@ -48,13 +48,15 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): ...@@ -48,13 +48,15 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]]):
files_path = [] files_path = []
for ext in ("hdf", "hdf5", "h5", "he5"): for ext in ("hdf", "hdf5", "h5", "he5"):
files_path += list(path.rglob(f"*.{ext}")) files_path += list(path.rglob(f"*.{ext}"))
files.append(files_path) files.append(files_path[:max_files])
else: else:
files.append([path]) files.append([path])
return files return files
def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): def get_all_poses_from_paths(
paths: Iterable[Union[str, Path]], predicate=None, max_files=None
):
"""Read all poses from given paths. """Read all poses from given paths.
The function shall be used by the evaluation functions. The function shall be used by the evaluation functions.
...@@ -67,16 +69,19 @@ def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): ...@@ -67,16 +69,19 @@ def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None):
[paths][files][entities, timesteps, 4], [paths][files][entities, timesteps, 4],
the common frequency of the files the common frequency of the files
""" """
return get_all_data_from_paths(paths, "poses_4d", predicate) return get_all_data_from_paths(paths, "poses_4d", predicate, max_files=max_files)
def get_all_data_from_paths( def get_all_data_from_paths(
paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None paths: Iterable[Union[str, Path]],
request_type="poses_4d",
predicate=None,
max_files=None,
): ):
expected_settings = None expected_settings = None
all_data = [] all_data = []
files_per_path = get_all_files_from_paths(paths) files_per_path = get_all_files_from_paths(paths, max_files)
pbar = tqdm( pbar = tqdm(
total=sum([len(files_in_path) for files_in_path in files_per_path]), total=sum([len(files_in_path) for files_in_path in files_per_path]),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment