diff --git a/config.ini b/config.ini new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/robofish/io/track.py b/src/robofish/io/track.py index dd27c8f814b3ad22692249e9fc5090d32c31fe3a..f1cd04b334b35524a7cadced506fe1317e0aea1c 100644 --- a/src/robofish/io/track.py +++ b/src/robofish/io/track.py @@ -2,13 +2,24 @@ # import configparser # config.read("../../config.ini") +# Neccessary functions: +# - save everything +# - convenient loading +# - interpolation + + import robofish.io.util as util default_trackformat_version = [1, 0] class Track: - def __init__(self, world_size, version=default_trackformat_version): + time = None + entity_list = [] + + def __init__( + self, world_size, monotonic_time=None, version=default_trackformat_version + ): assert len(world_size) == 2 assert len(version) == 2 self.track = { @@ -17,25 +28,79 @@ class Track: "g_entities": {}, } - def _create_entity(self, entity_name, poses=None, outlines=None): - e_data = {time: {}} - if poses: + def create_entities( + self, + type="Fish", + n=None, + names=None, + poses_bulk=None, + monotonic_steps=None, + monotonic_points=None, + calendar_points=None, + ): + # Nothing specified + if n is None and names is None and poses_bulk is None: + logger.error("It's neccessary to specify either n, names or trajectories") + return [] + # Names specified + if poses_bulk is not None: + if n and len(poses_bulk) != n: + logger.error("The length of poses bulk does not match n") + n = len(poses_bulk) + if names is not None: + if n and len(names) != n: + logger.error( + "The length of entity names does not match n or the trajectory shape" + ) + n = len(names) + else: + names = ["%s_%d" % (type, i) for i in range(n)] + + for i, name in enumerate(names): + poses = None + if poses_bulk is not None: + poses = poses_bulk[i] + self._create_single_entity( + name, + poses=poses, + monotonic_steps=monotonic_steps, + monotonic_points=monotonic_points, + calendar_points=calendar_points, + ) + return names + + def _create_single_entity( + self, + entity_name, + poses=None, + outlines=None, + monotonic_steps=None, + monotonic_points=None, + calendar_points=None, + ): + e_data = {"g_time": {}} + if poses is not None: e_data.update({"d_poses": poses}) - if outlines: + if outlines is not None: e_data.update({"d_outlines": outlines}) + if monotonic_steps is not None: + e_data["g_time"].update({"a_monotonic steps": monotonic_steps}) + if monotonic_points is not None: + e_data["g_time"].update({"d_monotonic step": monotonic_points}) + if calendar_points is not None: + e_data["g_time"].update({"d_calendar points": calendar_points}) - if entity_name not in self.track["entities"]: - self.track["entities"].update({entity_name: e_data}) + if entity_name not in self.track["g_entities"]: + self.track["g_entities"].update({entity_name: e_data}) - def add_to_entity(self, entity_name, key, pre, data): + def _add_to_entity(self, entity_name, key, pre, data): self._create_entity(entity_name) key_w_pre = pre + "_" + key - self.track["entities"][entity_name][key] = data + self.track["g_entities"][entity_name][key] = data - def add_time_to_entity(self, entity_name, key, data): - # add_to_entity(entity_name,"time",data) + def _add_time_to_entity(self, entity_name, key, data): self._create_entity(entity_name) - self.track["entities"][entity_name]["time"][key] = data + self.track["g_entities"][entity_name]["g_time"][key] = data def validate(self): util.validate_track(self.track) diff --git a/src/robofish/io/util.py b/src/robofish/io/util.py index 49a5a6818d813c6d1c18ccb1ee2c4a55c8c7614e..373e27ac751580b38a179c292f930fdbc7a61c4b 100644 --- a/src/robofish/io/util.py +++ b/src/robofish/io/util.py @@ -205,7 +205,7 @@ def validate_track(track, throw_exception=True): # time time = e_data["g_time"] - if "a_monotonoc steps" in time: + if "a_monotonic steps" in time: pass elif "d_monotonic points" in time: monotonic_points = time["d_monotonic points"] diff --git a/tests/robofish/io/test_track.py b/tests/robofish/io/test_track.py index e575818c07efe1f6bf8df6896ab09235cbe2607f..9bc7e3d9764ed92bb5e27c822ebc58a3cc47ee79 100644 --- a/tests/robofish/io/test_track.py +++ b/tests/robofish/io/test_track.py @@ -1,6 +1,22 @@ from robofish.io.track import Track +from robofish.io import util + +import numpy as np def test_init(): - track = Track([100, 100]) + timesteps = 10 + fishes = 2 + pose_dim = 4 + world_size = [100, 100] + poses_bulk = np.zeros(shape=(timesteps, fishes, pose_dim)) + + track = Track(world_size) + names = track.create_entities( + type="fish", poses_bulk=poses_bulk, monotonic_steps=40 + ) + + print(track.track) + + util.write_hdf5_from_track("test.hdf5", track.track) # track.add_entity_poses("fish1", [[[1, 1, 0, 0], [1, 1, 0, 0]]])