diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 1dfb4d1d512c1486380d028c853ece7f66e67f43..0d8396de57681daa208bffab879faf12f0f17fad 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -176,8 +176,19 @@ class File(h5py.File): # x Create file, fail if exists # a Read/write if exists, create otherwise logging.info(f"Opening File {path}") - initialize = not Path(path).exists() - super().__init__(path, mode, libver=("earliest", "v110")) + + assert mode in ["r", "r+", "w", "x", "a"], f"Unknown mode {mode}." + + # If the file does not exist or if it should be truncated with mode=w, initialize it. + if Path(path).exists() and mode != "w": + initialize = False + else: + initialize = True + + try: + super().__init__(path, mode, libver=("earliest", "v110")) + except OSError as e: + raise OSError(f"Could not open file {path} with mode {mode}.\n{e}") if initialize: assert world_size_cm is not None and format_version is not None @@ -264,13 +275,30 @@ class File(h5py.File): sampling = self["samplings"].create_group(name) - if frequency_hz is not None: - sampling.attrs["frequency_hz"] = (np.float32)(frequency_hz) if monotonic_time_points_us is not None: + + monotonic_time_points_us = np.array( + monotonic_time_points_us, dtype=np.int64 + ) sampling.create_dataset( - "monotonic_time_points_us", - data=np.array(monotonic_time_points_us, dtype=np.int64), + "monotonic_time_points_us", data=monotonic_time_points_us ) + if frequency_hz is None: + diff = np.diff(monotonic_time_points_us) + if np.all(diff == diff[0]) and diff[0] > 0: + frequency_hz = 1e6 / diff[0] + warnings.warn( + f"The frequency_hz of {frequency_hz:.2f}hz was calculated automatically by robofish.io. The safer variant is to pass it using frequency_hz.\nThis is important when using fish_models with the files." + ) + + else: + warnings.warn( + "The frequency_hz could not be calculated automatically. When using fish_models, the file will access frequency_hz." + ) + + if frequency_hz is not None: + sampling.attrs["frequency_hz"] = (np.float32)(frequency_hz) + if calendar_time_points is not None: def format_calendar_time_point(p): @@ -314,8 +342,9 @@ class File(h5py.File): @property def default_sampling(self): - if not "samplings" in self: - print("Wassss?", self) + assert ( + "samplings" in self + ), "The file does not have a group 'sampling' which is required." if "default" in self["samplings"].attrs: return self["samplings"].attrs["default"] return None