diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 54659cbda35722e51df2ab627bc41d74b0057cfb..bfb900cd175212eb45b67839d29513b43e333d92 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -533,8 +533,6 @@ def evaluate_follow_iid( grids = [] - # Exclude possible nans - # Find the 0.5%/ 99.5% quantiles as min and max for the axis follow_range = np.max( [ @@ -678,23 +676,27 @@ def evaluate_tracks( def evaluate_individual_speed( - speeds_turns_from_paths=None, file_settings=None, **kwargs + speeds_turns_from_paths=None, file_settings=None, max_files=None, **kwargs ): """Evaluate the average speeds per individual""" return evaluate_individuals( speeds_turns_from_paths=speeds_turns_from_paths, file_settings=file_settings, mode="speed", + max_files=max_files, **kwargs, ) -def evaluate_individual_iid(poses_from_paths=None, file_settings=None, **kwargs): +def evaluate_individual_iid( + poses_from_paths=None, file_settings=None, max_files=None, **kwargs +): """Evaluate the average iid per file""" return evaluate_individuals( poses_from_paths=poses_from_paths, file_settings=file_settings, mode="iid", + max_files=max_files, **kwargs, ) @@ -709,6 +711,7 @@ def evaluate_individuals( predicate=None, lw_distances=False, threshold=7, + max_files=None, ): """Evaluate the average speeds per individual. @@ -723,14 +726,14 @@ def evaluate_individuals( if speeds_turns_from_paths is None and mode == "speed": speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( - paths, "speeds_turns" + paths, "speeds_turns", max_files=max_files ) if poses_from_paths is None and mode == "iid": poses_from_paths, file_settings = utils.get_all_data_from_paths( - paths, "poses_4d" + paths, "poses_4d", max_files=max_files ) - files_from_paths = utils.get_all_files_from_paths(paths) + files_from_paths = utils.get_all_files_from_paths(paths, max_files) fig = plt.figure(figsize=(10, 4)) # small_iid_files = [] @@ -813,6 +816,7 @@ def evaluate_all( save_folder: Path = None, fdict: dict = None, predicate=None, + max_files=None, ): """Generate all evaluation graphs and save them to a folder. @@ -838,15 +842,18 @@ def evaluate_all( fdict = robofish.evaluate.app.function_dict() fdict.pop("all") - poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths( + paths, max_files=max_files + ) speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( - paths, "speeds_turns" + paths, "speeds_turns", max_files=max_files ) input_dict = { "poses_from_paths": poses_from_paths, "speeds_turns_from_paths": speeds_turns_from_paths, "file_settings": file_settings, + "max_files": max_files, } t = tqdm(fdict.items(), desc="Evaluation", leave=True) diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index e1349cdd8dda4bf4b3a08629e0178ddaa41b15a0..c17f8efebcef10e602e463991e72a4438d285f3c 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -40,7 +40,7 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)): return angle -def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): +def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None): # Find all files with correct ending files = [] for path in [Path(p) for p in paths]: @@ -48,13 +48,15 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): files_path = [] for ext in ("hdf", "hdf5", "h5", "he5"): files_path += list(path.rglob(f"*.{ext}")) - files.append(files_path) + files.append(files_path[:max_files]) else: files.append([path]) return files -def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): +def get_all_poses_from_paths( + paths: Iterable[Union[str, Path]], predicate=None, max_files=None +): """Read all poses from given paths. The function shall be used by the evaluation functions. @@ -67,16 +69,19 @@ def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): [paths][files][entities, timesteps, 4], the common frequency of the files """ - return get_all_data_from_paths(paths, "poses_4d", predicate) + return get_all_data_from_paths(paths, "poses_4d", predicate, max_files=max_files) def get_all_data_from_paths( - paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None + paths: Iterable[Union[str, Path]], + request_type="poses_4d", + predicate=None, + max_files=None, ): expected_settings = None all_data = [] - files_per_path = get_all_files_from_paths(paths) + files_per_path = get_all_files_from_paths(paths, max_files) pbar = tqdm( total=sum([len(files_in_path) for files_in_path in files_per_path]),