From d7c2ffa19cadd5639f36dc3e5c9cdbef79864098 Mon Sep 17 00:00:00 2001
From: Moritz Maxeiner <mm@ucw.sh>
Date: Thu, 18 Feb 2021 13:02:21 +0100
Subject: [PATCH] Use robofish.io.File as base class for TrackFile

---
 python/setup.py                           |  2 +-
 python/src/robofish/trackviewer/common.py | 33 +++++++++--------------
 2 files changed, 14 insertions(+), 21 deletions(-)

diff --git a/python/setup.py b/python/setup.py
index 44b01b6..9eadde6 100644
--- a/python/setup.py
+++ b/python/setup.py
@@ -62,7 +62,7 @@ entry_points = {
     "console_scripts": ["robofish-trackviewer-render=robofish.trackviewer.render_video:main"]
 }
 
-install_requires = ["numpy", "h5py", "robofish-core", "PySide2>=5.15"]
+install_requires = ["numpy", "h5py", "robofish-core", "robofish-io", "PySide2>=5.15"]
 
 
 def source_version():
diff --git a/python/src/robofish/trackviewer/common.py b/python/src/robofish/trackviewer/common.py
index 2f0c186..068a4fd 100644
--- a/python/src/robofish/trackviewer/common.py
+++ b/python/src/robofish/trackviewer/common.py
@@ -9,6 +9,7 @@ from numpy import zeros, float32, linspace, radians, clip
 import h5py
 
 from robofish.core import entity_boundary_distance, entity_entities_minimum_sector_distances
+import robofish.io
 
 from robofish.trackviewer.cpp import (
     FrameRenderer,
@@ -225,21 +226,16 @@ class ViewAction(Action):
         setattr(namespace, self.dest, items)
 
 
-class TrackFile:
-    def __init__(self, file_name, *, merged_fps=30):
-        self.f = h5py.File(file_name, "r")
+class TrackFile(robofish.io.File):
+    def __init__(self, file_name):
+        super().__init__(file_name, "r")
 
-        self.merged_fps = merged_fps
-
-        assert self.f.attrs["format_version"][0] == 1
-        self.world_size = self.f.attrs["world_size_cm"]
-
-        self.entity_names = list(self.f["entities"].keys())
+        assert self.attrs["format_version"][0] == 1
 
         self._parse_samplings_and_entities()
 
     def _parse_samplings_and_entities(self):
-        samplings = self.f["samplings"]
+        samplings = self["samplings"]
 
         assert len(samplings.values()) == 1, "Only a single sampling supported"
 
@@ -249,17 +245,14 @@ class TrackFile:
         assert "frequency_hz" in default_sampling.attrs, "Frequency required"
         self.time_step = 1000.0 / default_sampling.attrs["frequency_hz"]
 
-        self.poses = {}
-        for entity_name, entity_group in self.f["entities"].items():
-            positions = entity_group["positions"]
-            orientations = entity_group["orientations"]
-            assert positions.shape == orientations.shape, "Positions and orientations required"
-            self.poses[entity_name] = np.concatenate([positions, orientations], axis=1)
+        self.__entity_poses = {}
+        for entity in self.entities:
+            self.__entity_poses[entity.name] = entity.poses
 
-        assert len(np.unique([p.shape for p in self.poses.values()], axis=0)) == 1, "All poses must have the same shape"
-        assert len(self.poses) >= 1, "At least one pose required"
-        self.num_frames = list(self.poses.values())[0].shape[0]
+        assert len(np.unique([p.shape for p in self.__entity_poses.values()], axis=0)) == 1, "All poses must have the same shape"
+        assert len(self.__entity_poses) >= 1, "At least one pose required"
+        self.num_frames = list(self.__entity_poses.values())[0].shape[0]
 
 
     def frame(self, entity_name, frame_index):
-        return self.poses[entity_name][frame_index]
+        return self.__entity_poses[entity_name][frame_index]
-- 
GitLab