diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 00e5c5488caa9ee77bdc3694fccf93df79fbe0f3..c037b583f82cac3fa5fa146b434b13d8fd3c410f 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -10,6 +10,7 @@ Functions available to be used in the commandline to evaluate robofish.io files. # Last doku update Feb 2021 import robofish.evaluate + import argparse from pathlib import Path import matplotlib.pyplot as plt diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 9b169b1c7bc06a4b821a34e44e53593f7d0537cc..02d4af1db6c00549033ee2f28ffb495184ee015d 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -24,77 +24,6 @@ import inspect import random -def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): - # Find all files with correct ending - files = [] - for path in [Path(p) for p in paths]: - if path.is_dir(): - files_path = [] - for ext in ("hdf", "hdf5", "h5", "he5"): - files_path += list(path.rglob(f"*.{ext}")) - files.append(files_path) - else: - files.append([path]) - return files - - -def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): - """Read all poses from given paths. - - The function shall be used by the evaluation functions. - - Args: - paths: An array of strings, with files or folders. - The files are checked to have the same frequency. - Returns: - An array, containing poses with the shape - [paths][files][entities, timesteps, 4], - the common frequency of the files - """ - return get_all_data_from_paths(paths, "poses_4d", predicate) - - -def get_all_data_from_paths( - paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None -): - expected_settings = None - all_data = [] - - files_per_path = get_all_files_from_paths(paths) - - # for each given path - for files_in_path in files_per_path: - data_from_files = [] - - # Open all files, gather the poses and check if all of them have the same world size and frequency - for i_path, file_path in enumerate(files_in_path): - with robofish.io.File(file_path, "r") as file: - file_settings = { - "world_size_cm_x": file.attrs["world_size_cm"][0], - "world_size_cm_y": file.attrs["world_size_cm"][1], - "frequency_hz": file.frequency, - } - if expected_settings is None: - expected_settings = file_settings - - assert file_settings == expected_settings - - properties = { - "poses_4d": robofish.io.Entity.poses, - "speeds_turns": robofish.io.Entity.speed_turn, - } - - pred = None if predicate is None else predicate[i_path] - data = file.select_entity_property( - pred, entity_property=properties[request_type] - ) - data_from_files.append(data) - - all_data.append(data_from_files) - - return all_data, expected_settings - - def evaluate_speed( paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, @@ -113,7 +42,7 @@ def evaluate_speed( """ if speeds_turns_from_paths is None: - speeds_turns_from_paths, file_settings = get_all_data_from_paths( + speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( paths, "speeds_turns" ) @@ -172,7 +101,7 @@ def evaluate_turn( """ if speeds_turns_from_paths is None: - speeds_turns_from_paths, file_settings = get_all_data_from_paths( + speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( paths, "speeds_turns" ) @@ -233,7 +162,7 @@ def evaluate_orientation( (example: lambda e: e.category == "fish") """ if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) world_bounds = [ -file_settings["world_size_cm_x"] / 2, @@ -317,7 +246,7 @@ def evaluate_relative_orientation( """ if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) orientations = [] # Iterate all paths @@ -367,7 +296,7 @@ def evaluate_distance_to_wall( """ if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) world_bounds = [ -file_settings["world_size_cm_x"] / 2, @@ -453,7 +382,7 @@ def evaluate_tank_position( poses_step = 20 if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) xy_positions = [] @@ -512,7 +441,7 @@ def evaluate_social_vector( """ if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) socialVec = [] @@ -585,7 +514,7 @@ def evaluate_follow_iid( (example: lambda e: e.category == "fish") """ if poses_from_paths is None: - poses_from_paths, file_settings = get_all_poses_from_paths(paths) + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) follow, iid = [], [] @@ -708,7 +637,7 @@ def evaluate_tracks( random.seed(seed) - files_per_path = get_all_files_from_paths(paths) + files_per_path = utils.get_all_files_from_paths(paths) max_files_per_path = max([len(files) for files in files_per_path]) rows, cols = len(files_per_path), min(4, max_files_per_path) @@ -780,13 +709,15 @@ def evaluate_individuals( """ if speeds_turns_from_paths is None and mode == "speed": - speeds_turns_from_paths, file_settings = get_all_data_from_paths( + speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( paths, "speeds_turns" ) if poses_from_paths is None and mode == "iid": - poses_from_paths, file_settings = get_all_data_from_paths(paths, "poses_4d") + poses_from_paths, file_settings = utils.get_all_data_from_paths( + paths, "poses_4d" + ) - files_from_paths = get_all_files_from_paths(paths) + files_from_paths = utils.get_all_files_from_paths(paths) fig = plt.figure(figsize=(10, 4)) # small_iid_files = [] @@ -798,7 +729,8 @@ def evaluate_individuals( for f, file_path in enumerate(files_in_paths): if mode == "speed": metric = ( - speeds_turns_from_paths[k][f][:, 0] * file_settings["frequency_hz"] + speeds_turns_from_paths[k][f][..., 0] + * file_settings["frequency_hz"] ) elif mode == "iid": poses = poses_from_paths[k][f] @@ -894,8 +826,8 @@ def evaluate_all( fdict.pop("all") print("Loading all poses and actions.") - poses_from_paths, file_settings = get_all_poses_from_paths(paths) - speeds_turns_from_paths, file_settings = get_all_data_from_paths( + poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) + speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( paths, "speeds_turns" ) diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 0518dedd81d5f9c4842583d1e687b14d7e353d4a..e555f92f88447277179110183ae2cb4a7ae16ced 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -11,6 +11,8 @@ # ----------------------------------------------------------- import robofish.io +from robofish.io import utils + import argparse import logging @@ -77,25 +79,30 @@ def validate(args=None): logging.getLogger().setLevel(logging.ERROR) - sf_dict = robofish.io.read_multiple_files(args.path) + files_per_path = utils.get_all_files_from_paths(args.path) + files = [ + f for f_in_path in files_per_path for f in f_in_path + ] # Concatenate all files to one list - if len(sf_dict) == 0: + if len(files) == 0: logging.getLogger().setLevel(logging.INFO) logging.info("No files found in %s" % args.path) return + validity_dict = {} + for fp in files: + with robofish.io.File(fp) as f: + validity_dict[str(fp)] = f.validate(strict_validate=False) + if args.output_format == "raw": - sf_dict = { - (str)(f): sf.validate(strict_validate=False)[0] for f, sf in sf_dict.items() - } - return sf_dict + return validity_dict - max_filename_width = max([len((str)(f)) for f in sf_dict.keys()]) + max_filename_width = max([len(str(f)) for f in files]) error_code = 0 - for file, sf in sf_dict.items(): - filled_file = (str)(file).ljust(max_filename_width + 3) - validity, validity_message = sf.validate(strict_validate=False) - sf.close() + for fp, (validity, validity_message) in validity_dict.items(): + + filled_file = (str)(fp).ljust(max_filename_width + 3) + if not validity: error_code = 1 print(f"{filled_file}:{validity}\t{validity_message}") diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 0637e8961c5125d92d90979afecb5243e9e150bf..f51815476905a0e87169c54016a7ee554f4c00bd 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -65,6 +65,7 @@ class File(h5py.File): mode: str = "r", *, # PEP 3102 world_size_cm: List[int] = None, + validate: bool = False, strict_validate: bool = False, format_version: List[int] = default_format_version, format_url: str = default_format_url, @@ -93,6 +94,8 @@ class File(h5py.File): world_size_cm : [int, int] , optional side lengths [x, y] of the world in cm. rectangular world shape is assumed. + validate: bool, default=False + Should the track be validated? This is normally switched off for performance reasons. strict_validate : bool, default=False if the file should be strictly validated against the track format specification, when loaded from a path. @@ -210,7 +213,8 @@ class File(h5py.File): calendar_time_points=calendar_time_points, default=True, ) - self.validate(strict_validate) + if validate: + self.validate(strict_validate) def __enter__(self): return self @@ -994,7 +998,7 @@ class File(h5py.File): update, frames=n_frames, init_func=init, - blit=False, + blit=True, interval=self.frequency, repeat=False, ) diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 50be83f8bf4a290aaf23dea88aaf58c65f5816a0..bac783619ddc0be738f092105d2f6be36f15b26f 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -38,3 +38,74 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)): else: angle = limit_simple(angle) return angle + + +def get_all_files_from_paths(paths: Iterable[Union[str, Path]]): + # Find all files with correct ending + files = [] + for path in [Path(p) for p in paths]: + if path.is_dir(): + files_path = [] + for ext in ("hdf", "hdf5", "h5", "he5"): + files_path += list(path.rglob(f"*.{ext}")) + files.append(files_path) + else: + files.append([path]) + return files + + +def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None): + """Read all poses from given paths. + + The function shall be used by the evaluation functions. + + Args: + paths: An array of strings, with files or folders. + The files are checked to have the same frequency. + Returns: + An array, containing poses with the shape + [paths][files][entities, timesteps, 4], + the common frequency of the files + """ + return get_all_data_from_paths(paths, "poses_4d", predicate) + + +def get_all_data_from_paths( + paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None +): + expected_settings = None + all_data = [] + + files_per_path = get_all_files_from_paths(paths) + + # for each given path + for files_in_path in files_per_path: + data_from_files = [] + + # Open all files, gather the poses and check if all of them have the same world size and frequency + for i_path, file_path in enumerate(files_in_path): + with robofish.io.File(file_path, "r") as file: + file_settings = { + "world_size_cm_x": file.attrs["world_size_cm"][0], + "world_size_cm_y": file.attrs["world_size_cm"][1], + "frequency_hz": file.frequency, + } + if expected_settings is None: + expected_settings = file_settings + + assert file_settings == expected_settings + + properties = { + "poses_4d": robofish.io.Entity.poses, + "speeds_turns": robofish.io.Entity.actions_speeds_turns, + } + + pred = None if predicate is None else predicate[i_path] + data = file.select_entity_property( + pred, entity_property=properties[request_type] + ) + data_from_files.append(data) + + all_data.append(data_from_files) + + return all_data, expected_settings diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index 7d5ffd8e82ab7065f4122c2399080030957fa7d7..e8809a63bceffa3fbce68ba5f19019a92f8ac9ae 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -19,11 +19,11 @@ def test_app_validate(): self.path = path self.output_format = output_format - raw_output = app.validate(DummyArgs(resources_path, "raw")) + raw_output = app.validate(DummyArgs([resources_path], "raw")) # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. assert len(raw_output) == 4 - app.validate(DummyArgs(resources_path, "human")) + app.validate(DummyArgs([resources_path], "human")) def test_app_print(): diff --git a/tests/robofish/io/test_io.py b/tests/robofish/io/test_io.py index ad0f68f64a9a35785a065c6b71d8730fb3bfe7c8..f13d28b525f55ed15a6de6b4876605ee8b96d648 100644 --- a/tests/robofish/io/test_io.py +++ b/tests/robofish/io/test_io.py @@ -1,12 +1,4 @@ import robofish.io -from robofish.io import utils -import pytest -from pathlib import Path -import numpy as np - - -resources_path = utils.full_path(__file__, "../../resources/") -h5py_file = utils.full_path(__file__, "../../resources/valid_1.hdf5") def test_now_iso8061(): @@ -14,43 +6,3 @@ def test_now_iso8061(): time = robofish.io.now_iso8061() assert type(time) == str assert len(time) == 32 - - -def test_read_multiple_single(): - path = h5py_file - - # Variants path as posix path or as string - for sf in [ - robofish.io.read_multiple_files(h5py_file), - robofish.io.read_multiple_files(str(h5py_file)), - ]: - assert len(sf) == 1 - for p, f in sf.items(): - assert p == path - assert type(f) == robofish.io.File - - -def test_read_multiple_folder(): - # Variants path as posix path or as string - for sf in [ - robofish.io.read_multiple_files(resources_path), - robofish.io.read_multiple_files(str(resources_path)), - ]: - # Should find the 4 available hdf5 files - assert len(sf) == 4 - for p, f in sf.items(): - print(p) - assert type(f) == robofish.io.File - - -# TODO read from folder of valid files -@pytest.mark.parametrize("param_path", [h5py_file, str(h5py_file)]) -def test_read_poses_rad_from_multiple_folder(param_path): - poses = robofish.io.read_property_from_multiple_files( - [param_path, param_path], robofish.io.entity.Entity.poses_rad - ) - # Should find the 3 presaved hdf5 files - assert len(poses) == 2 - for p in poses: - print(p) - assert type(p) == np.ndarray diff --git a/tests/robofish/evaluate/test_evaluate.py b/tests/robofish/io/test_utils.py similarity index 67% rename from tests/robofish/evaluate/test_evaluate.py rename to tests/robofish/io/test_utils.py index 1059d1ee71881ac0507673517c12e2ba0986e0bc..ed607df8253d5290b35d6f6ed3b28b2c484a0733 100644 --- a/tests/robofish/evaluate/test_evaluate.py +++ b/tests/robofish/io/test_utils.py @@ -1,12 +1,13 @@ -import pytest -import robofish.evaluate +from robofish.io import utils + from robofish.io import utils import numpy as np def test_get_all_poses_from_paths(): valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5") - poses, file_settings = robofish.evaluate.get_all_poses_from_paths([valid_file_path]) + poses, file_settings = utils.get_all_poses_from_paths([valid_file_path]) # (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses) assert np.array(poses).shape == (1, 1, 2, 100, 4) + type(file_settings) == dict