Skip to content
Snippets Groups Projects
Commit b3782164 authored by Andi Gerken's avatar Andi Gerken
Browse files

Added buffer for poses_rad and actions_speeds_turns so that they don't have to...

Added buffer for poses_rad and actions_speeds_turns so that they don't have to be recalculated every time.
parent 8a45018e
No related branches found
No related tags found
No related merge requests found
Pipeline #49484 passed
...@@ -10,6 +10,7 @@ entry_points = { ...@@ -10,6 +10,7 @@ entry_points = {
"robofish-io-validate=robofish.io.app:validate", "robofish-io-validate=robofish.io.app:validate",
"robofish-io-print=robofish.io.app:print_file", "robofish-io-print=robofish.io.app:print_file",
"robofish-io-render=robofish.io.app:render", "robofish-io-render=robofish.io.app:render",
"robofish-io-clear-calculated-data=robofish.io.app:clear_calculated_data",
# TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz
"robofish-io-evaluate=robofish.evaluate.app:evaluate", "robofish-io-evaluate=robofish.evaluate.app:evaluate",
] ]
......
...@@ -20,3 +20,9 @@ import robofish.io.app ...@@ -20,3 +20,9 @@ import robofish.io.app
if not ((3, 7) <= sys.version_info < (4, 0)): if not ((3, 7) <= sys.version_info < (4, 0)):
logging.warning("Unsupported Python version") logging.warning("Unsupported Python version")
warn_when_unable_to_store = True
def disable_warning_when_unable_to_store():
robofish.io.warn_when_unable_to_store = False
...@@ -53,6 +53,33 @@ def print_file(args=None): ...@@ -53,6 +53,33 @@ def print_file(args=None):
return not valid return not valid
def clear_calculated_data(args=None):
parser = argparse.ArgumentParser(
description="This function clears calculated data from robofish.io files."
)
parser.add_argument(
"path",
type=str,
nargs="+",
help="The path to one or multiple files and/or folders.",
)
if args is None:
args = parser.parse_args()
files_per_path = utils.get_all_files_from_paths(args.path)
files = [
f for f_in_path in files_per_path for f in f_in_path
] # Concatenate all files to one list
assert len(files) > 0, f"No files found in path {args.path}."
for fp in files:
print(f"File {fp}")
with robofish.io.File(fp, "r+") as f:
f.clear_calculated_data()
def validate(args=None): def validate(args=None):
"""This function can be used to validate hdf5 files. """This function can be used to validate hdf5 files.
......
...@@ -143,11 +143,33 @@ class Entity(h5py.Group): ...@@ -143,11 +143,33 @@ class Entity(h5py.Group):
@property @property
def orientations_rad(self): def orientations_rad(self):
ori_rad = utils.limit_angle_range( # If actions_speeds_turns does not exist yet or the poses hash is not correct
np.arctan2(self.orientations[:, 1], self.orientations[:, 0]), if (
_range=(0, 2 * np.pi), "calculated_orientations_rad" not in self
) or "poses_hash" not in self.attrs
return ori_rad[:, np.newaxis] or self.poses_hash != self.attrs["poses_hash"]
):
ori_rad = utils.limit_angle_range(
np.arctan2(self.orientations[:, 1], self.orientations[:, 0]),
_range=(0, 2 * np.pi),
)[:, np.newaxis]
try:
self.attrs["poses_hash"] = self.poses_hash
self["calculated_orientations_rad"] = ori_rad.astype(np.float64)
except RuntimeError as e:
if robofish.io.warn_when_unable_to_store:
print(
"Trying to store calculated orientations_rad in file to reuse it but the file was opened as read-only.\n"
"If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n"
"To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n"
)
else:
ori_rad = self["calculated_orientations_rad"]
return ori_rad
@property @property
@deprecation.deprecated( @deprecation.deprecated(
...@@ -196,6 +218,14 @@ class Entity(h5py.Group): ...@@ -196,6 +218,14 @@ class Entity(h5py.Group):
return poses_with_calculated_orientation return poses_with_calculated_orientation
@property
def poses_hash(self):
if "orientations" in self:
h = (hash(self["positions"]) + hash(self["orientations"])) // 2
else:
h = hash(self["positions"])
return int(h)
@property @property
def poses(self): def poses(self):
return np.concatenate([self.positions, self.orientations], axis=1) return np.concatenate([self.positions, self.orientations], axis=1)
...@@ -249,13 +279,35 @@ class Entity(h5py.Group): ...@@ -249,13 +279,35 @@ class Entity(h5py.Group):
An array with shape (number_of_positions -1, 2 (speed in cm/frame, turn in rad/frame). An array with shape (number_of_positions -1, 2 (speed in cm/frame, turn in rad/frame).
""" """
ori = self.orientations # If actions_speeds_turns does not exist yet or the poses hash is not correct
ori_rad = self.orientations_rad if (
pos = self.positions "calculated_actions_speeds_turns" not in self
turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0]) or "poses_hash" not in self.attrs
pos_diff = np.diff(pos, axis=0) or self.poses_hash != self.attrs["poses_hash"]
speed = np.array( ):
[np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])] ori = self.orientations
) ori_rad = self.orientations_rad
pos = self.positions
turn = utils.limit_angle_range(np.diff(ori_rad, axis=0)[:, 0])
pos_diff = np.diff(pos, axis=0)
speed = np.array(
[np.dot(pos_diff[i], ori[i + 1]) for i in range(pos_diff.shape[0])]
)
actions_speeds_turns = np.stack([speed, turn], axis=-1)
return np.stack([speed, turn], axis=-1) try:
self.attrs["poses_hash"] = self.poses_hash
self["calculated_actions_speeds_turns"] = actions_speeds_turns.astype(
np.float64
)
except RuntimeError as e:
if robofish.io.warn_when_unable_to_store:
print(
"Trying to store calculated actions_speeds_turns in file to reuse it but the file was opened as read-only.\n"
"If you open the file with mode 'r+' (or 'w' if it is still open from creation time) the information can be stored and reused.\n"
"To disable this message execute robofish.io.disable_warning_when_unable_to_store().\n"
)
else:
actions_speeds_turns = self["calculated_actions_speeds_turns"]
return actions_speeds_turns
...@@ -470,6 +470,24 @@ class File(h5py.File): ...@@ -470,6 +470,24 @@ class File(h5py.File):
) )
return entity_names return entity_names
def clear_calculated_data(self, verbose=True):
"""Delete all calculated data from the files."""
txt = ""
for e in self.entities:
txt += f"Deleting from {e}. Attrs: ["
for a in ["poses_hash"]:
if a in e.attrs:
del e.attrs[a]
txt += f"{a}, "
txt = txt[:-2] + "] Datasets: ["
for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
if g in e:
del e[g]
txt += f"{g}, "
txt = txt[:-2] + "]\n"
if verbose:
print(txt[:-1])
@property @property
def entity_names(self) -> Iterable[str]: def entity_names(self) -> Iterable[str]:
"""Getter for the names of all entities """Getter for the names of all entities
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment