diff --git a/setup.py b/setup.py index 48573d8af70ab56fc4c1098994c2b9b3b4e7014f..a9ebe9c5d415412a41abe7e72af70e682d3e9601 100644 --- a/setup.py +++ b/setup.py @@ -10,7 +10,8 @@ entry_points = { "robofish-io-validate=robofish.io.app:validate", "robofish-io-print=robofish.io.app:print_file", "robofish-io-render=robofish.io.app:render", - "robofish-io-clear-calculated-data=robofish.io.app:clear_calculated_data", + # "robofish-io-clear-calculated-data=robofish.io.app:clear_calculated_data", + "robofish-io-update-calculated-data=robofish.io.app:update_calculated_data", # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz "robofish-io-evaluate=robofish.evaluate.app:evaluate", ] diff --git a/src/robofish/io/__init__.py b/src/robofish/io/__init__.py index b14b319e8e1899321e1140140d0846dcee92662e..5b054c2fb71018dd23ceb315d83bf83f51cc1cac 100644 --- a/src/robofish/io/__init__.py +++ b/src/robofish/io/__init__.py @@ -20,9 +20,3 @@ import robofish.io.app if not ((3, 7) <= sys.version_info < (4, 0)): logging.warning("Unsupported Python version") - -warn_when_unable_to_store = True - - -def disable_warning_when_unable_to_store(): - robofish.io.warn_when_unable_to_store = False diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 15b6ebe15cb09eac6bf38b901b86184bf0da5077..c6a93852d63d07d4bc7b598075e280759dbd4992 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -15,6 +15,7 @@ from robofish.io import utils import argparse import logging +import warnings def print_file(args=None): @@ -53,9 +54,9 @@ def print_file(args=None): return not valid -def clear_calculated_data(args=None): +def update_calculated_data(args=None): parser = argparse.ArgumentParser( - description="This function clears calculated data from robofish.io files." + description="This function updates all calculated data from files." ) parser.add_argument( @@ -76,8 +77,42 @@ def clear_calculated_data(args=None): for fp in files: print(f"File {fp}") - with robofish.io.File(fp, "r+") as f: - f.clear_calculated_data() + try: + with robofish.io.File(fp, "r+", validate_poses_hash=False) as f: + f.update_calculated_data(verbose=True) + except Exception as e: + warnings.warn(f"The file {fp} could not be updated.") + print(e) + + +# This should not be neccessary since the data will always have the calculated data by default. +# def clear_calculated_data(args=None): +# parser = argparse.ArgumentParser( +# description="This function clears calculated data from robofish.io files." +# ) + +# parser.add_argument( +# "path", +# type=str, +# nargs="+", +# help="The path to one or multiple files and/or folders.", +# ) +# if args is None: +# args = parser.parse_args() + +# files_per_path = utils.get_all_files_from_paths(args.path) +# files = [ +# f for f_in_path in files_per_path for f in f_in_path +# ] # Concatenate all files to one list + +# assert len(files) > 0, f"No files found in path {args.path}." + +# for fp in files: +# print(f"File {fp}") +# with robofish.io.File( +# fp, "r+", validate_poses_hash=False, store_calculated_data=False +# ) as f: +# f.clear_calculated_data() def validate(args=None): @@ -162,7 +197,7 @@ def render(args=None): default_options = { "linewidth": 2, - "speedup": 4, + "speedup": 1, "trail": 100, "entity_scale": 0.2, "fixed_view": False, diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index fb0b1a209ec070aab86b9e44b02d1829bb0be417..3375f377909410f5b6044e2bfdf444d078a09fc7 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -60,6 +60,8 @@ class Entity(h5py.Group): if outlines is not None: entity.create_outlines(outlines, sampling) + entity.update_calculated_data() + return entity @classmethod @@ -141,36 +143,6 @@ class Entity(h5py.Group): return np.tile([0, 1], (self.positions.shape[0], 1)) return self["orientations"] - @property - def orientations_rad(self): - # If actions_speeds_turns does not exist yet or the poses hash is not correct - if ( - "calculated_orientations_rad" not in self - or "poses_hash" not in self.attrs - or self.poses_hash != self.attrs["poses_hash"] - ): - - ori_rad = utils.limit_angle_range( - np.arctan2(self.orientations[:, 1], self.orientations[:, 0]), - _range=(0, 2 * np.pi), - )[:, np.newaxis] - - try: - - self.attrs["poses_hash"] = self.poses_hash - self["calculated_orientations_rad"] = ori_rad.astype(np.float64) - except RuntimeError as e: - if robofish.io.warn_when_unable_to_store: - print( - "Trying to store calculated orientations_rad in file to reuse it but the file was opened as read-only.\n" - "If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n" - "To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n" - ) - else: - ori_rad = self["calculated_orientations_rad"] - - return ori_rad - @property @deprecation.deprecated( deprecated_in="0.2", @@ -220,11 +192,19 @@ class Entity(h5py.Group): @property def poses_hash(self): + # The hash of h5py datasets changes each time the file is reopened. + # Also the hash of casting the array to bytes and calculating the hash changes. + def npsumhash(a): + return hash(np.nansum(a)) + if "orientations" in self: - h = (hash(self["positions"]) + hash(self["orientations"])) // 2 + h = (npsumhash(self["positions"]) + npsumhash(self["orientations"])) // 2 + elif "positions" in self: + print("We found positions") + h = npsumhash(self["positions"]) else: - h = hash(self["positions"]) - return int(h) + h = 0 + return h @property def poses(self): @@ -267,8 +247,59 @@ class Entity(h5py.Group): turn = utils.limit_angle_range(diff[:, 2], _range=(-np.pi, np.pi)) return np.stack([speed, turn], axis=-1) - @property - def actions_speeds_turns(self): + def update_calculated_data(self, verbose=False, force_update=False): + if ( + "poses_hash" not in self.attrs + or self.attrs["poses_hash"] != self.poses_hash + or "calculated_orientations_rad" not in self + or "calculated_actions_speeds_turns" not in self + or "unfinished_calculations" in self.attrs + or force_update + ): + try: + self.attrs["poses_hash"] = self.poses_hash + self.attrs["unfinished_calculations"] = True + if "orientations" in self: + ori_rad = self.calculate_orientations_rad() + if "calculated_orientations_rad" in self: + del self["calculated_orientations_rad"] + self["calculated_orientations_rad"] = ori_rad.astype(np.float64) + + speeds_turns = self.calculate_actions_speeds_turns() + if "calculated_actions_speeds_turns" in self: + del self["calculated_actions_speeds_turns"] + self["calculated_actions_speeds_turns"] = speeds_turns.astype( + np.float64 + ) + del self.attrs["unfinished_calculations"] + + if verbose: + print( + f"Updated calculated data for entity {self.name} with poses_hash {self.poses_hash}" + ) + elif verbose: + print( + "Since there were no orientations in the data, nothing was calculated." + ) + except RuntimeError as e: + print("Trying to update calculated data in a read-only file") + raise e + else: + if verbose: + print( + f"Nothing to be updated in entity {self.name}. Poses_hash was {self.attrs['poses_hash']}" + ) + + assert self.attrs["poses_hash"] == self.poses_hash + + def calculate_orientations_rad(self): + ori_rad = utils.limit_angle_range( + np.arctan2(self.orientations[:, 1], self.orientations[:, 0]), + _range=(0, 2 * np.pi), + )[:, np.newaxis] + return ori_rad + + def calculate_actions_speeds_turns(self): """Calculate the speed, turn and from the recorded positions and orientations. The turn is calculated by the change of orientation between frames. @@ -278,36 +309,32 @@ class Entity(h5py.Group): Returns: An array with shape (number_of_positions -1, 2 (speed in cm/frame, turn in rad/frame). """ + ori = self.orientations + ori_rad = self.orientations_rad + pos = self.positions + turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0]) + pos_diff = np.diff(pos, axis=0) + speed = np.array( + [np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])] + ) + return np.stack([speed, turn], axis=-1) - # If actions_speeds_turns does not exist yet or the poses hash is not correct - if ( - "calculated_actions_speeds_turns" not in self - or "poses_hash" not in self.attrs - or self.poses_hash != self.attrs["poses_hash"] - ): - ori = self.orientations - ori_rad = self.orientations_rad - pos = self.positions - turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0]) - pos_diff = np.diff(pos, axis=0) - speed = np.array( - [np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])] - ) - actions_speeds_turns = np.stack([speed, turn], axis=-1) - - try: - self.attrs["poses_hash"] = self.poses_hash - self["calculated_actions_speeds_turns"] = actions_speeds_turns.astype( - np.float64 - ) - except RuntimeError as e: - if robofish.io.warn_when_unable_to_store: - print( - "Trying to store calculated actions_speeds_turns in file to reuse it but the file was opened as read-only.\n" - "If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n" - "To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n" - ) + @property + def actions_speeds_turns(self): + if "calculated_actions_speeds_turns" in self: + assert ( + self.attrs["poses_hash"] == self.poses_hash + ), f"The calculated poses_hash was not identical to the stored poses_hash. Please update the calculated data after changing positions or orientations with entity.update_calculated_data(). stored hash: {self.attrs['poses_hash']}, calculated hash: {self.poses_hash}." + return self["calculated_actions_speeds_turns"] else: - actions_speeds_turns = self["calculated_actions_speeds_turns"] + return self.calculate_actions_speeds_turns() - return actions_speeds_turns + @property + def orientations_rad(self): + if "calculated_orientations_rad" in self: + assert ( + self.attrs["poses_hash"] == self.poses_hash + ), f"The calculated poses_hash was not identical to the stored poses_hash. Please update the calculated data after changing positions or orientations with entity.update_calculated_data(). stored hash: {self.attrs['poses_hash']}, calculated hash: {self.poses_hash}." + return self["calculated_orientations_rad"] + else: + return self.calculate_orientations_rad() diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 40097b5a992886fbae5f103d7c4e3f231a7b3fe2..d57c7bf89ceb99508f1a4c09f14824c5b244b0c3 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -75,6 +75,7 @@ class File(h5py.File): monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, open_copy: bool = False, + validate_poses_hash: bool = True, ): """Create a new RoboFish Track Format object. @@ -214,6 +215,38 @@ class File(h5py.File): calendar_time_points=calendar_time_points, default=True, ) + else: + # A quick validation to find h5py files which are not robofish.io files + if any([a not in self.attrs for a in ["world_size_cm", "format_version"]]): + msg = f"The opened file {self.path} does not include world_size_cm or format_version. It seems that the file is not a robofish.io.File." + if strict_validate: + raise KeyError(msg) + else: + warnings.warn(msg) + return + + # Validate that the stored poses hash still fits. + if validate_poses_hash: + for entity in self.entities: + if "poses_hash" in entity.attrs: + if entity.attrs["poses_hash"] != entity.poses_hash: + warnings.warn( + f"The stored hash is not identical with the newly calculated hash. In entity {entity.name} in {self.path}. f.entity_actions_turns_speeds and f.entity_orientation_rad will return wrong results.\n" + f"stored: {entity.attrs['poses_hash']}, calculated: {entity.poses_hash}" + ) + assert ( + "unfinished_calculations" not in entity.attrs + ), f"The calculated data of file {self.path} is uncomplete and was probably aborted during calculation. please recalculate with `robofish-io-update-calculated-data {self.path}`." + + else: + warnings.warn( + f"The file did not include pre-calculated data so the actions_speeds_turns " + f"and orientations_rad will have to be be recalculated everytime.\n" + f"Please use `robofish-io-update-calculated-data {self.path}` in the " + f"commandline or\nopen and close the file with robofish.io.File(f, 'r+') " + f"in python.\nIf the data should be recalculated every time open the file " + "with the bool option validate_poses_hash=False." + ) if validate: self.validate(strict_validate) @@ -225,8 +258,14 @@ class File(h5py.File): if (type, value, traceback) == (None, None, None): if self.mode != "r": # No need to validate read only files (performance). self.validate() + super().__exit__(type, value, traceback) + def close(self): + if self.mode != "r": + self.update_calculated_data() + super().close() + def save_as( self, path: Union[str, Path], @@ -243,6 +282,7 @@ class File(h5py.File): The file itself, so something like f = robofish.io.File().save_as("file.hdf5") works """ + self.update_calculated_data() self.validate(strict_validate=strict_validate) # Ensure all buffered data has been written to disk @@ -471,6 +511,10 @@ class File(h5py.File): ) return entity_names + def update_calculated_data(self, verbose=False): + for e in self.entities: + e.update_calculated_data(verbose) + def clear_calculated_data(self, verbose=True): """Delete all calculated data from the files.""" txt = "" @@ -1129,7 +1173,7 @@ class File(h5py.File): frames=n_frames, init_func=init, blit=platform.system() != "Darwin", - interval=self.frequency, + interval=1000 / self.frequency, repeat=False, ) diff --git a/tests/resources/nan_test.hdf5 b/tests/resources/nan_test.hdf5 index 3ff017ef600aec8928a7e55c38a68c2821876e22..92c3050e3b5204721b0eb0a0543adf0c4d5f65c0 100644 Binary files a/tests/resources/nan_test.hdf5 and b/tests/resources/nan_test.hdf5 differ diff --git a/tests/resources/valid_1.hdf5 b/tests/resources/valid_1.hdf5 index de8a1dabe0f28ecb0b1600aa01918328858e0f8f..664e40a39b908364ac70cf99dc11715a6ecd2abc 100644 Binary files a/tests/resources/valid_1.hdf5 and b/tests/resources/valid_1.hdf5 differ diff --git a/tests/resources/valid_2.hdf5 b/tests/resources/valid_2.hdf5 index 10caedaafa4d0d13c4be720f2c8d48dec184a092..eb8c69e1a01f1704c29c94b7c85590c5443275dd 100644 Binary files a/tests/resources/valid_2.hdf5 and b/tests/resources/valid_2.hdf5 differ diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index ea4933f28ae92e86b0bbc0f9e75b0890ae5fba3b..6c3a16f62e6f34708624236ec4d1525be4b39254 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -2,7 +2,6 @@ import robofish.io.app as app from robofish.io import utils import pytest import logging -from pathlib import Path logging.getLogger().setLevel(logging.INFO) @@ -19,11 +18,16 @@ def test_app_validate(): self.path = path self.output_format = output_format - raw_output = app.validate(DummyArgs([resources_path], "raw")) + # invalid.hdf5 should pass a warning + with pytest.warns(UserWarning): + raw_output = app.validate(DummyArgs([resources_path], "raw")) # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. assert len(raw_output) == 4 - app.validate(DummyArgs([resources_path], "human")) + + # invalid.hdf5 should pass a warning + with pytest.warns(UserWarning): + app.validate(DummyArgs([resources_path], "human")) def test_app_print():