Skip to content
Snippets Groups Projects
Commit d7c2ffa1 authored by calrama's avatar calrama
Browse files

Use robofish.io.File as base class for TrackFile

parent 614f54e0
No related branches found
No related tags found
No related merge requests found
Pipeline #35768 passed
...@@ -62,7 +62,7 @@ entry_points = { ...@@ -62,7 +62,7 @@ entry_points = {
"console_scripts": ["robofish-trackviewer-render=robofish.trackviewer.render_video:main"] "console_scripts": ["robofish-trackviewer-render=robofish.trackviewer.render_video:main"]
} }
install_requires = ["numpy", "h5py", "robofish-core", "PySide2>=5.15"] install_requires = ["numpy", "h5py", "robofish-core", "robofish-io", "PySide2>=5.15"]
def source_version(): def source_version():
......
...@@ -9,6 +9,7 @@ from numpy import zeros, float32, linspace, radians, clip ...@@ -9,6 +9,7 @@ from numpy import zeros, float32, linspace, radians, clip
import h5py import h5py
from robofish.core import entity_boundary_distance, entity_entities_minimum_sector_distances from robofish.core import entity_boundary_distance, entity_entities_minimum_sector_distances
import robofish.io
from robofish.trackviewer.cpp import ( from robofish.trackviewer.cpp import (
FrameRenderer, FrameRenderer,
...@@ -225,21 +226,16 @@ class ViewAction(Action): ...@@ -225,21 +226,16 @@ class ViewAction(Action):
setattr(namespace, self.dest, items) setattr(namespace, self.dest, items)
class TrackFile: class TrackFile(robofish.io.File):
def __init__(self, file_name, *, merged_fps=30): def __init__(self, file_name):
self.f = h5py.File(file_name, "r") super().__init__(file_name, "r")
self.merged_fps = merged_fps assert self.attrs["format_version"][0] == 1
assert self.f.attrs["format_version"][0] == 1
self.world_size = self.f.attrs["world_size_cm"]
self.entity_names = list(self.f["entities"].keys())
self._parse_samplings_and_entities() self._parse_samplings_and_entities()
def _parse_samplings_and_entities(self): def _parse_samplings_and_entities(self):
samplings = self.f["samplings"] samplings = self["samplings"]
assert len(samplings.values()) == 1, "Only a single sampling supported" assert len(samplings.values()) == 1, "Only a single sampling supported"
...@@ -249,17 +245,14 @@ class TrackFile: ...@@ -249,17 +245,14 @@ class TrackFile:
assert "frequency_hz" in default_sampling.attrs, "Frequency required" assert "frequency_hz" in default_sampling.attrs, "Frequency required"
self.time_step = 1000.0 / default_sampling.attrs["frequency_hz"] self.time_step = 1000.0 / default_sampling.attrs["frequency_hz"]
self.poses = {} self.__entity_poses = {}
for entity_name, entity_group in self.f["entities"].items(): for entity in self.entities:
positions = entity_group["positions"] self.__entity_poses[entity.name] = entity.poses
orientations = entity_group["orientations"]
assert positions.shape == orientations.shape, "Positions and orientations required"
self.poses[entity_name] = np.concatenate([positions, orientations], axis=1)
assert len(np.unique([p.shape for p in self.poses.values()], axis=0)) == 1, "All poses must have the same shape" assert len(np.unique([p.shape for p in self.__entity_poses.values()], axis=0)) == 1, "All poses must have the same shape"
assert len(self.poses) >= 1, "At least one pose required" assert len(self.__entity_poses) >= 1, "At least one pose required"
self.num_frames = list(self.poses.values())[0].shape[0] self.num_frames = list(self.__entity_poses.values())[0].shape[0]
def frame(self, entity_name, frame_index): def frame(self, entity_name, frame_index):
return self.poses[entity_name][frame_index] return self.__entity_poses[entity_name][frame_index]
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment