Skip to content
Snippets Groups Projects
Commit 59ab8e20 authored by Andi Gerken's avatar Andi Gerken
Browse files

development on tracks.py

parent 5c3846b3
Branches
Tags
1 merge request!2development on tracks.py
Pipeline #33538 passed
...@@ -2,13 +2,24 @@ ...@@ -2,13 +2,24 @@
# import configparser # import configparser
# config.read("../../config.ini") # config.read("../../config.ini")
# Neccessary functions:
# - save everything
# - convenient loading
# - interpolation
import robofish.io.util as util import robofish.io.util as util
default_trackformat_version = [1, 0] default_trackformat_version = [1, 0]
class Track: class Track:
def __init__(self, world_size, version=default_trackformat_version): time = None
entity_list = []
def __init__(
self, world_size, monotonic_time=None, version=default_trackformat_version
):
assert len(world_size) == 2 assert len(world_size) == 2
assert len(version) == 2 assert len(version) == 2
self.track = { self.track = {
...@@ -17,25 +28,79 @@ class Track: ...@@ -17,25 +28,79 @@ class Track:
"g_entities": {}, "g_entities": {},
} }
def _create_entity(self, entity_name, poses=None, outlines=None): def create_entities(
e_data = {time: {}} self,
if poses: type="Fish",
n=None,
names=None,
poses_bulk=None,
monotonic_steps=None,
monotonic_points=None,
calendar_points=None,
):
# Nothing specified
if n is None and names is None and poses_bulk is None:
logger.error("It's neccessary to specify either n, names or trajectories")
return []
# Names specified
if poses_bulk is not None:
if n and len(poses_bulk) != n:
logger.error("The length of poses bulk does not match n")
n = len(poses_bulk)
if names is not None:
if n and len(names) != n:
logger.error(
"The length of entity names does not match n or the trajectory shape"
)
n = len(names)
else:
names = ["%s_%d" % (type, i) for i in range(n)]
for i, name in enumerate(names):
poses = None
if poses_bulk is not None:
poses = poses_bulk[i]
self._create_single_entity(
name,
poses=poses,
monotonic_steps=monotonic_steps,
monotonic_points=monotonic_points,
calendar_points=calendar_points,
)
return names
def _create_single_entity(
self,
entity_name,
poses=None,
outlines=None,
monotonic_steps=None,
monotonic_points=None,
calendar_points=None,
):
e_data = {"g_time": {}}
if poses is not None:
e_data.update({"d_poses": poses}) e_data.update({"d_poses": poses})
if outlines: if outlines is not None:
e_data.update({"d_outlines": outlines}) e_data.update({"d_outlines": outlines})
if monotonic_steps is not None:
e_data["g_time"].update({"a_monotonic steps": monotonic_steps})
if monotonic_points is not None:
e_data["g_time"].update({"d_monotonic step": monotonic_points})
if calendar_points is not None:
e_data["g_time"].update({"d_calendar points": calendar_points})
if entity_name not in self.track["entities"]: if entity_name not in self.track["g_entities"]:
self.track["entities"].update({entity_name: e_data}) self.track["g_entities"].update({entity_name: e_data})
def add_to_entity(self, entity_name, key, pre, data): def _add_to_entity(self, entity_name, key, pre, data):
self._create_entity(entity_name) self._create_entity(entity_name)
key_w_pre = pre + "_" + key key_w_pre = pre + "_" + key
self.track["entities"][entity_name][key] = data self.track["g_entities"][entity_name][key] = data
def add_time_to_entity(self, entity_name, key, data): def _add_time_to_entity(self, entity_name, key, data):
# add_to_entity(entity_name,"time",data)
self._create_entity(entity_name) self._create_entity(entity_name)
self.track["entities"][entity_name]["time"][key] = data self.track["g_entities"][entity_name]["g_time"][key] = data
def validate(self): def validate(self):
util.validate_track(self.track) util.validate_track(self.track)
......
...@@ -205,7 +205,7 @@ def validate_track(track, throw_exception=True): ...@@ -205,7 +205,7 @@ def validate_track(track, throw_exception=True):
# time # time
time = e_data["g_time"] time = e_data["g_time"]
if "a_monotonoc steps" in time: if "a_monotonic steps" in time:
pass pass
elif "d_monotonic points" in time: elif "d_monotonic points" in time:
monotonic_points = time["d_monotonic points"] monotonic_points = time["d_monotonic points"]
......
from robofish.io.track import Track from robofish.io.track import Track
from robofish.io import util
import numpy as np
def test_init(): def test_init():
track = Track([100, 100]) timesteps = 10
fishes = 2
pose_dim = 4
world_size = [100, 100]
poses_bulk = np.zeros(shape=(timesteps, fishes, pose_dim))
track = Track(world_size)
names = track.create_entities(
type="fish", poses_bulk=poses_bulk, monotonic_steps=40
)
print(track.track)
util.write_hdf5_from_track("test.hdf5", track.track)
# track.add_entity_poses("fish1", [[[1, 1, 0, 0], [1, 1, 0, 0]]]) # track.add_entity_poses("fish1", [[[1, 1, 0, 0], [1, 1, 0, 0]]])
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment