Skip to content
Snippets Groups Projects
Commit d7c2ffa1 authored by calrama's avatar calrama
Browse files

Use robofish.io.File as base class for TrackFile

parent 614f54e0
Branches
Tags
No related merge requests found
Pipeline #35768 passed
......@@ -62,7 +62,7 @@ entry_points = {
"console_scripts": ["robofish-trackviewer-render=robofish.trackviewer.render_video:main"]
}
install_requires = ["numpy", "h5py", "robofish-core", "PySide2>=5.15"]
install_requires = ["numpy", "h5py", "robofish-core", "robofish-io", "PySide2>=5.15"]
def source_version():
......
......@@ -9,6 +9,7 @@ from numpy import zeros, float32, linspace, radians, clip
import h5py
from robofish.core import entity_boundary_distance, entity_entities_minimum_sector_distances
import robofish.io
from robofish.trackviewer.cpp import (
FrameRenderer,
......@@ -225,21 +226,16 @@ class ViewAction(Action):
setattr(namespace, self.dest, items)
class TrackFile:
def __init__(self, file_name, *, merged_fps=30):
self.f = h5py.File(file_name, "r")
class TrackFile(robofish.io.File):
def __init__(self, file_name):
super().__init__(file_name, "r")
self.merged_fps = merged_fps
assert self.f.attrs["format_version"][0] == 1
self.world_size = self.f.attrs["world_size_cm"]
self.entity_names = list(self.f["entities"].keys())
assert self.attrs["format_version"][0] == 1
self._parse_samplings_and_entities()
def _parse_samplings_and_entities(self):
samplings = self.f["samplings"]
samplings = self["samplings"]
assert len(samplings.values()) == 1, "Only a single sampling supported"
......@@ -249,17 +245,14 @@ class TrackFile:
assert "frequency_hz" in default_sampling.attrs, "Frequency required"
self.time_step = 1000.0 / default_sampling.attrs["frequency_hz"]
self.poses = {}
for entity_name, entity_group in self.f["entities"].items():
positions = entity_group["positions"]
orientations = entity_group["orientations"]
assert positions.shape == orientations.shape, "Positions and orientations required"
self.poses[entity_name] = np.concatenate([positions, orientations], axis=1)
self.__entity_poses = {}
for entity in self.entities:
self.__entity_poses[entity.name] = entity.poses
assert len(np.unique([p.shape for p in self.poses.values()], axis=0)) == 1, "All poses must have the same shape"
assert len(self.poses) >= 1, "At least one pose required"
self.num_frames = list(self.poses.values())[0].shape[0]
assert len(np.unique([p.shape for p in self.__entity_poses.values()], axis=0)) == 1, "All poses must have the same shape"
assert len(self.__entity_poses) >= 1, "At least one pose required"
self.num_frames = list(self.__entity_poses.values())[0].shape[0]
def frame(self, entity_name, frame_index):
return self.poses[entity_name][frame_index]
return self.__entity_poses[entity_name][frame_index]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment