From d7c2ffa19cadd5639f36dc3e5c9cdbef79864098 Mon Sep 17 00:00:00 2001 From: Moritz Maxeiner <mm@ucw.sh> Date: Thu, 18 Feb 2021 13:02:21 +0100 Subject: [PATCH] Use robofish.io.File as base class for TrackFile --- python/setup.py | 2 +- python/src/robofish/trackviewer/common.py | 33 +++++++++-------------- 2 files changed, 14 insertions(+), 21 deletions(-) diff --git a/python/setup.py b/python/setup.py index 44b01b6..9eadde6 100644 --- a/python/setup.py +++ b/python/setup.py @@ -62,7 +62,7 @@ entry_points = { "console_scripts": ["robofish-trackviewer-render=robofish.trackviewer.render_video:main"] } -install_requires = ["numpy", "h5py", "robofish-core", "PySide2>=5.15"] +install_requires = ["numpy", "h5py", "robofish-core", "robofish-io", "PySide2>=5.15"] def source_version(): diff --git a/python/src/robofish/trackviewer/common.py b/python/src/robofish/trackviewer/common.py index 2f0c186..068a4fd 100644 --- a/python/src/robofish/trackviewer/common.py +++ b/python/src/robofish/trackviewer/common.py @@ -9,6 +9,7 @@ from numpy import zeros, float32, linspace, radians, clip import h5py from robofish.core import entity_boundary_distance, entity_entities_minimum_sector_distances +import robofish.io from robofish.trackviewer.cpp import ( FrameRenderer, @@ -225,21 +226,16 @@ class ViewAction(Action): setattr(namespace, self.dest, items) -class TrackFile: - def __init__(self, file_name, *, merged_fps=30): - self.f = h5py.File(file_name, "r") +class TrackFile(robofish.io.File): + def __init__(self, file_name): + super().__init__(file_name, "r") - self.merged_fps = merged_fps - - assert self.f.attrs["format_version"][0] == 1 - self.world_size = self.f.attrs["world_size_cm"] - - self.entity_names = list(self.f["entities"].keys()) + assert self.attrs["format_version"][0] == 1 self._parse_samplings_and_entities() def _parse_samplings_and_entities(self): - samplings = self.f["samplings"] + samplings = self["samplings"] assert len(samplings.values()) == 1, "Only a single sampling supported" @@ -249,17 +245,14 @@ class TrackFile: assert "frequency_hz" in default_sampling.attrs, "Frequency required" self.time_step = 1000.0 / default_sampling.attrs["frequency_hz"] - self.poses = {} - for entity_name, entity_group in self.f["entities"].items(): - positions = entity_group["positions"] - orientations = entity_group["orientations"] - assert positions.shape == orientations.shape, "Positions and orientations required" - self.poses[entity_name] = np.concatenate([positions, orientations], axis=1) + self.__entity_poses = {} + for entity in self.entities: + self.__entity_poses[entity.name] = entity.poses - assert len(np.unique([p.shape for p in self.poses.values()], axis=0)) == 1, "All poses must have the same shape" - assert len(self.poses) >= 1, "At least one pose required" - self.num_frames = list(self.poses.values())[0].shape[0] + assert len(np.unique([p.shape for p in self.__entity_poses.values()], axis=0)) == 1, "All poses must have the same shape" + assert len(self.__entity_poses) >= 1, "At least one pose required" + self.num_frames = list(self.__entity_poses.values())[0].shape[0] def frame(self, entity_name, frame_index): - return self.poses[entity_name][frame_index] + return self.__entity_poses[entity_name][frame_index] -- GitLab