diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 3b2f1fc7a3fefb825fb46fd79e94f3fc7dd44bb4..05db9fa6dbd5ea6689dcf58d73aebdedef6e8d72 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -131,7 +131,8 @@ class File(h5py.File): def __exit__(self, type, value, traceback): # Check if the context was left under normal circumstances if (type, value, traceback) == (None, None, None): - self.validate() + if self.mode != "r": # No need to validate read only files (performance). + self.validate() self.close() def save_as(self, path: Union[str, Path], strict_validate: bool = True): diff --git a/src/robofish/io/io.py b/src/robofish/io/io.py index 5b626a4ec069ba7e548cd80fa28eb6f557fccc48..ab46b5ba67dfb4451f62a31a82805a40861bb598 100644 --- a/src/robofish/io/io.py +++ b/src/robofish/io/io.py @@ -10,7 +10,7 @@ import random def now_iso8061() -> str: - """ The current time as iso8061 string. + """The current time as iso8061 string. Returns: str: The current time as iso8061 string. @@ -24,10 +24,9 @@ def read_multiple_files( paths: Union[Path, str, Iterable[Path], Iterable[str]], strict_validate: bool = False, max_files: int = None, - shuffle: bool = False, ) -> dict: - """ Load hdf5 files from a given path. + """Load hdf5 files from a given path. The function can be given the path to a single single hdf5 file, to a folder, containing hdf5 files, or an array of multiple files or folders. @@ -35,7 +34,7 @@ def read_multiple_files( Args: path: The path to a hdf5 file or folder. strict_validate: Choice between error and warning in case of invalidity - max: Maximum number of files to be read + max_files: Maximum number of files to be read Returns: dict: A dictionary where the keys are filenames and the opened robofish.io.File objects """ @@ -49,8 +48,6 @@ def read_multiple_files( paths = [Path(p) for p in paths] sf_dict = {} - if shuffle: - random.shuffle(paths) for path in paths: if path.is_dir(): logging.info("found dir %s" % path) @@ -72,3 +69,65 @@ def read_multiple_files( logging.info("found file %s" % path) sf_dict[path] = robofish.io.File(path=path, strict_validate=strict_validate) return sf_dict + + +def read_poses_from_multiple_files( + paths: Union[Path, str, Iterable[Path], Iterable[str]], + strict_validate: bool = False, + max_files: int = None, + shuffle: bool = False, + ori_rad: bool = False, +): + """Load hdf5 files from a given path and return the entity poses. + + The function can be given the path to a single single hdf5 file, to a folder, + containing hdf5 files, or an array of multiple files or folders. + + Args: + path: The path to a hdf5 file or folder. + strict_validate: Choice between error and warning in case of invalidity + max_files: Maximum number of files to be read + shuffle: Shuffle the order of files + ori_rad: Return the orientations as radiants instead of unit vectors + Returns: + An array of all entity poses arrays + """ + + logging.info(f"Reading files from path {paths}") + + list_types = (list, np.ndarray, pandas.core.series.Series) + if not isinstance(paths, list_types): + paths = [paths] + + paths = [Path(p) for p in paths] + + poses_array = [] + for path in paths: + if path.is_dir(): + + logging.info("found dir %s" % path) + # Find all hdf5 files in folder + files = [] + for ext in ("hdf", "hdf5", "h5", "he5"): + files += list(path.rglob(f"*.{ext}")) + files = random.shuffle(files) if shuffle else sorted(files) + + logging.info("Reading files") + + for file in files: + if max_files is not None and len(poses_array) >= max_files: + break + + if not file.is_dir(): + with robofish.io.File( + path=file, strict_validate=strict_validate + ) as f: + p = f.entity_poses_rad if ori_rad else f.entity_poses + poses_array.append(p) + + elif path is not None and path.exists(): + logging.info("found file %s" % path) + with robofish.io.File(path=file, strict_validate=strict_validate) as f: + p = f.entity_poses_rad if ori_rad else f.entity_poses + poses_array.append(p) + return poses_array \ No newline at end of file diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 8ea10175c7fed8d0e4a813856290116d86d83b54..f8ae264c83f8314deab290db71bd2318da939f71 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -235,7 +235,8 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): e_name, ) - validate_orientations_length(orientations, e_name) + if strict_validate: + validate_orientations_length(orientations, e_name) # outlines if "outlines" in entity: @@ -345,7 +346,7 @@ def validate_orientations_length(orientations, e_name): "The orientation vectors were not unit vectors. Their length was in the range [%.4f, %.4f] when it should be 1" % (min(ori_lengths), max(ori_lengths)), e_name, - strict_validate=True, + strict_validate=False, ) diff --git a/tests/robofish/io/test_io.py b/tests/robofish/io/test_io.py index f0609d2038a50d17a522479f2efbe57f3c13d301..13bd8761111f17e93945b3f7fa02dfe00ab712e2 100644 --- a/tests/robofish/io/test_io.py +++ b/tests/robofish/io/test_io.py @@ -2,6 +2,7 @@ import robofish.io from robofish.io import utils import pytest from pathlib import Path +import numpy as np def test_now_iso8061(): @@ -38,3 +39,16 @@ def test_read_multiple_folder(): for p, f in sf.items(): print(p) assert type(f) == robofish.io.File + + +path = utils.full_path(__file__, "../../resources/valid.hdf5") + +# TODO read from folder of valid files +@pytest.mark.parametrize("_path", [path, str(path)]) +def test_read_poses_from_multiple_folder(_path): + poses = robofish.io.read_poses_from_multiple_files([_path, _path]) + # Should find the 3 presaved hdf5 files + assert len(poses) == 2 + for p in poses: + print(p) + assert type(p) == np.ndarray