diff --git a/setup.py b/setup.py index fd67f2ed2906a91b878e31b1a5754202034ce591..48573d8af70ab56fc4c1098994c2b9b3b4e7014f 100644 --- a/setup.py +++ b/setup.py @@ -10,6 +10,7 @@ entry_points = { "robofish-io-validate=robofish.io.app:validate", "robofish-io-print=robofish.io.app:print_file", "robofish-io-render=robofish.io.app:render", + "robofish-io-clear-calculated-data=robofish.io.app:clear_calculated_data", # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz "robofish-io-evaluate=robofish.evaluate.app:evaluate", ] diff --git a/src/robofish/io/__init__.py b/src/robofish/io/__init__.py index 5b054c2fb71018dd23ceb315d83bf83f51cc1cac..b14b319e8e1899321e1140140d0846dcee92662e 100644 --- a/src/robofish/io/__init__.py +++ b/src/robofish/io/__init__.py @@ -20,3 +20,9 @@ import robofish.io.app if not ((3, 7) <= sys.version_info < (4, 0)): logging.warning("Unsupported Python version") + +warn_when_unable_to_store = True + + +def disable_warning_when_unable_to_store(): + robofish.io.warn_when_unable_to_store = False diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index be1690d62049354a292a0d06627affd2c410d397..15b6ebe15cb09eac6bf38b901b86184bf0da5077 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -53,6 +53,33 @@ def print_file(args=None): return not valid +def clear_calculated_data(args=None): + parser = argparse.ArgumentParser( + description="This function clears calculated data from robofish.io files." + ) + + parser.add_argument( + "path", + type=str, + nargs="+", + help="The path to one or multiple files and/or folders.", + ) + if args is None: + args = parser.parse_args() + + files_per_path = utils.get_all_files_from_paths(args.path) + files = [ + f for f_in_path in files_per_path for f in f_in_path + ] # Concatenate all files to one list + + assert len(files) > 0, f"No files found in path {args.path}." + + for fp in files: + print(f"File {fp}") + with robofish.io.File(fp, "r+") as f: + f.clear_calculated_data() + + def validate(args=None): """This function can be used to validate hdf5 files. diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index 85c9c926ae7acc333759375d3039015b30c66454..fb0b1a209ec070aab86b9e44b02d1829bb0be417 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -143,11 +143,33 @@ class Entity(h5py.Group): @property def orientations_rad(self): - ori_rad = utils.limit_angle_range( - np.arctan2(self.orientations[:, 1], self.orientations[:, 0]), - _range=(0, 2 * np.pi), - ) - return ori_rad[:, np.newaxis] + # If actions_speeds_turns does not exist yet or the poses hash is not correct + if ( + "calculated_orientations_rad" not in self + or "poses_hash" not in self.attrs + or self.poses_hash != self.attrs["poses_hash"] + ): + + ori_rad = utils.limit_angle_range( + np.arctan2(self.orientations[:, 1], self.orientations[:, 0]), + _range=(0, 2 * np.pi), + )[:, np.newaxis] + + try: + + self.attrs["poses_hash"] = self.poses_hash + self["calculated_orientations_rad"] = ori_rad.astype(np.float64) + except RuntimeError as e: + if robofish.io.warn_when_unable_to_store: + print( + "Trying to store calculated orientations_rad in file to reuse it but the file was opened as read-only.\n" + "If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n" + "To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n" + ) + else: + ori_rad = self["calculated_orientations_rad"] + + return ori_rad @property @deprecation.deprecated( @@ -196,6 +218,14 @@ class Entity(h5py.Group): return poses_with_calculated_orientation + @property + def poses_hash(self): + if "orientations" in self: + h = (hash(self["positions"]) + hash(self["orientations"])) // 2 + else: + h = hash(self["positions"]) + return int(h) + @property def poses(self): return np.concatenate([self.positions, self.orientations], axis=1) @@ -249,13 +279,35 @@ class Entity(h5py.Group): An array with shape (number_of_positions -1, 2 (speed in cm/frame, turn in rad/frame). """ - ori = self.orientations - ori_rad = self.orientations_rad - pos = self.positions - turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0]) - pos_diff = np.diff(pos, axis=0) - speed = np.array( - [np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])] - ) + # If actions_speeds_turns does not exist yet or the poses hash is not correct + if ( + "calculated_actions_speeds_turns" not in self + or "poses_hash" not in self.attrs + or self.poses_hash != self.attrs["poses_hash"] + ): + ori = self.orientations + ori_rad = self.orientations_rad + pos = self.positions + turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0]) + pos_diff = np.diff(pos, axis=0) + speed = np.array( + [np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])] + ) + actions_speeds_turns = np.stack([speed, turn], axis=-1) - return np.stack([speed, turn], axis=-1) + try: + self.attrs["poses_hash"] = self.poses_hash + self["calculated_actions_speeds_turns"] = actions_speeds_turns.astype( + np.float64 + ) + except RuntimeError as e: + if robofish.io.warn_when_unable_to_store: + print( + "Trying to store calculated actions_speeds_turns in file to reuse it but the file was opened as read-only.\n" + "If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n" + "To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n" + ) + else: + actions_speeds_turns = self["calculated_actions_speeds_turns"] + + return actions_speeds_turns diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 1bf1ed173f0a25f84963a9f7e90f477d2c7e44e7..cdb2aedbbe52ab3c81f99590b1a992c20cff0b9b 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -470,6 +470,24 @@ class File(h5py.File): ) return entity_names + def clear_calculated_data(self, verbose=True): + """Delete all calculated data from the files.""" + txt = "" + for e in self.entities: + txt += f"Deleting from {e}. Attrs: [" + for a in ["poses_hash"]: + if a in e.attrs: + del e.attrs[a] + txt += f"{a}, " + txt = txt[:-2] + "] Datasets: [" + for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]: + if g in e: + del e[g] + txt += f"{g}, " + txt = txt[:-2] + "]\n" + if verbose: + print(txt[:-1]) + @property def entity_names(self) -> Iterable[str]: """Getter for the names of all entities