Commit 5c3846b3 authored by Andi Gerken

Added prefix to the track dict, Added beginning of track.py (under development)

parent 35e02d03
1 merge request: !1 Added prefix to the track dict, Added beginning of track.py (under development)
Pipeline #33514 passed
@@ -9,7 +9,7 @@ setup(
version="0.1",
author="",
author_email="",
install_requires=["h5py", "re", "numpy"],
install_requires=["h5py", "numpy"],
classifiers=[
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
......
# TODO: version should come from config
# import configparser
# config.read("../../config.ini")

import robofish.io.util as util

default_trackformat_version = [1, 0]


class Track:
    def __init__(self, world_size, version=default_trackformat_version):
        assert len(world_size) == 2
        assert len(version) == 2
        self.track = {
            "a_version": version,
            "a_world_size": world_size,
            "g_entities": {},
        }

    def _create_entity(self, entity_name, poses=None, outlines=None):
        # Every entity starts with an empty time group.
        e_data = {"g_time": {}}
        if poses is not None:
            e_data.update({"d_poses": poses})
        if outlines is not None:
            e_data.update({"d_outlines": outlines})
        if entity_name not in self.track["g_entities"]:
            self.track["g_entities"].update({entity_name: e_data})

    def add_to_entity(self, entity_name, key, pre, data):
        self._create_entity(entity_name)
        key_w_pre = pre + "_" + key
        self.track["g_entities"][entity_name][key_w_pre] = data

    def add_time_to_entity(self, entity_name, key, data):
        # add_to_entity(entity_name, "time", data)
        self._create_entity(entity_name)
        self.track["g_entities"][entity_name]["g_time"][key] = data

    def validate(self):
        util.validate_track(self.track)
# test_dict = {
#     "a_version": [1, 0],
#     "a_world_size": [100, 100],
#     "g_entities": {
#         "g_fish_1": {
#             "a_type": "fish",
#             "d_poses": [[0, 0, 1, 0], [0.1, 0, 1, 0], [0.2, 0, 1, 0]],
#             "d_outlines": [[[1, 1]]],
#             "g_time": {
#                 "d_monotonic points": [0, 1, 2],
#                 "d_calendar points": [
#                     "2020-11-18T13:21:34.117+01:00",
#                     "2020-11-18T13:21:34.117+01:00",
#                     "2020-11-18T13:21:34.117+01:00",
#                 ],
#             },
#         },
#         "g_obstacle1": {
#             "a_type": "obstacle",
#             "d_poses": [[0, 0, 1, 0]],
#             "d_outlines": [[[1, 1]]],
#             "g_time": {},
#         },
#     },
# }
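A minimal usage sketch of the draft Track class above. The add_* API is still under development, so the exact calls are an assumption based on the current signatures:

track = Track([100, 100])

# add_to_entity(entity_name, key, pre, data) builds the key "d_poses"
track.add_to_entity("g_fish_1", "poses", "d", [[0, 0, 1, 0], [0.1, 0, 1, 0]])

# Time data goes into the entity's g_time group; the prefixed key name
# follows the d_ convention and is an assumption.
track.add_time_to_entity("g_fish_1", "d_monotonic points", [0, 1])

# Raises AssertionError if the track violates the format
track.validate()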
@@ -5,6 +5,10 @@
# Robofish track format (1.0 Draft 7). The standard is available at
# https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
#
# The term track refers to a dictionary that describes a track.
# To distinguish between attributes, datasets, and groups, each key carries
# a prefix (a_ for attributes, d_ for datasets, and g_ for groups).
#
# Dec 2020 Andreas Gerken, Berlin, Germany
# Released under GNU 3.0 License
# email andi.gerken@gmail.com
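The prefix handling below splits each key once at the first underscore, so the remaining name may itself contain underscores or spaces. A quick illustration:

full_key = "d_monotonic points"
pre, key = full_key.split("_", maxsplit=1)
# pre == "d" (a dataset), key == "monotonic points"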
@@ -28,18 +32,12 @@ def write_hdf5_from_track(path, track, additional_dataset_names=[]):
Args:
path (str): The path to the output file which should be written.
track (dict): A dict of the track which should be written.
additional_dataset_names (str array): If you use additional datasets,
which are not specified in the standard, pass their names using this
argument. Example: ['new_dataset']
Returns:
object: The hdf5 object of the written file.
None
Throws:
AssertionError: When the track is invalid.
"""
dataset_names = ["poses", "outlines", "monotonic points", "calendar points"]
dataset_names.append(additional_dataset_names)
# Validate track
track = validate_track(track, throw_exception=True)
assert track
@@ -48,19 +46,22 @@ def write_hdf5_from_track(path, track, additional_dataset_names=[]):
def write_recursive(group, d):
# Attributes
for key, item in d.items():
if key in dataset_names:
for full_key, item in d.items():
pre, key = full_key.split("_", maxsplit=1)
if pre == "d":
group.create_dataset(key, data=item)
elif isinstance(item, dict):
elif pre == "g":
sub_group = group.create_group(key)
write_recursive(sub_group, d[key])
else:
write_recursive(sub_group, d[full_key])
elif pre == "a":
group.attrs[key] = item
else:
logging.warning("Prefix not known for %s" % full_key)
return
write_recursive(f, track)
return f
return
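A short usage sketch of the writer (file name hypothetical): a_ keys become HDF5 attributes, d_ keys become datasets, and g_ keys become sub-groups, with the prefixes stripped from the stored names:

import h5py
import robofish.io.util as util

track = {"a_version": [1, 0], "a_world_size": [100, 100], "g_entities": {}}
util.write_hdf5_from_track("example.hdf5", track)

with h5py.File("example.hdf5", "r") as f:
    print(dict(f.attrs))   # {'version': ..., 'world_size': ...}
    print(list(f.keys()))  # ['entities']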
def read_track_from_hdf5(path, throw_exception=True):
@@ -82,11 +83,11 @@ def read_track_from_hdf5(path, throw_exception=True):
def read_recursive(group):
# Attributes
d = {key: group.attrs[key] for key in group.attrs.keys()}
d = {"a_" + key: group.attrs[key] for key in group.attrs.keys()}
# Datasets
d.update(
{
key: np.array(group[key][:])
"d_" + key: np.array(group[key][:])
for key in group.keys()
if isinstance(group[key], h5py.Dataset)
}
@@ -94,7 +95,7 @@ def read_track_from_hdf5(path, throw_exception=True):
# Groups (recursive)
d.update(
{
key: read_recursive(group[key])
"g_" + key: read_recursive(group[key])
for key in group.keys()
if isinstance(group[key], h5py.Group)
}
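Since read_recursive re-attaches the prefixes that write_recursive strips, a write followed by a read should round-trip the dict. Continuing the hypothetical example above, and assuming read_track_from_hdf5 returns the reconstructed dict:

track_read = util.read_track_from_hdf5("example.hdf5")
assert "a_version" in track_read and "g_entities" in track_read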
@@ -133,11 +134,34 @@ def validate_track(track, throw_exception=True):
"""
array_dtypes = {
"root": {"version": np.int32, "world_size": np.float32},
"entity": {"poses": np.float32, "outlines": np.float32},
"entity_time": {"monotonic points": np.int64, "calendar points": "S"},
"root": {"a_version": np.int32, "a_world_size": np.float32},
"entity": {"d_poses": np.float32, "d_outlines": np.float32},
"entity_time": {"d_monotonic points": np.int64, "d_calendar points": "S"},
}
outlines = None
#### Checking for prefixes ####
def recursive_validate_prefix(dict_with_prefixes):
for full_key, value in dict_with_prefixes.items():
# split the key into pre and key
try:
pre, key = full_key.split("_", maxsplit=1)
except ValueError:
# A key without an underscore has no prefix
return False
# If value is a dict, check if pre is g and check the value recursively
if isinstance(value, dict):
if pre != "g" or not recursive_validate_prefix(value):
return False
# Check if the prefix is a or d
elif pre not in ["a", "d"]:
return False
return True
# Fail validation if any key is missing a known prefix
assert recursive_validate_prefix(track)
#### Reformatting to numpy arrays ####
def update_to_nparray(in_dict, dtypes):
in_dict.update(
@@ -153,49 +177,62 @@ def validate_track(track, throw_exception=True):
update_to_nparray(track, array_dtypes["root"])
#### Validation ####
assert track.keys() >= {"version", "world_size"}
for e_name, e_data in track["entities"].items():
assert track.keys() >= {"a_version", "a_world_size"}
for e_name, e_data in track["g_entities"].items():
# Format all entity arrays
update_to_nparray(track["entities"][e_name], array_dtypes["entity"])
update_to_nparray(track["g_entities"][e_name], array_dtypes["entity"])
update_to_nparray(
track["entities"][e_name]["time"], array_dtypes["entity_time"]
track["g_entities"][e_name]["g_time"], array_dtypes["entity_time"]
)
# poses
poses = e_data["poses"]
poses = e_data["d_poses"]
assert poses.ndim == 2
assert poses.shape[1] == 4
# outlines
if "outlines" in e_data:
outlines = e_data["outlines"]
if "d_outlines" in e_data:
outlines = e_data["d_outlines"]
# 3 dimensional array
assert outlines.ndim == 3
# Either fixed outline or same length with poses
assert outlines.shape[0] == 1 or outlines.shape[0] == poses.shape[0]
# Outline from two dimensional points
assert outlines.shape[2] == 2
# time
time = e_data["time"]
if "monotonoc steps" in time:
time = e_data["g_time"]
if "a_monotonoc steps" in time:
pass
elif "monotonic points" in time:
monotonic_points = time["monotonic points"]
elif "d_monotonic points" in time:
monotonic_points = time["d_monotonic points"]
# 1 dimensional array
assert monotonic_points.ndim == 1
# Either fixed in place or same length with poses
assert (
poses.shape[0] == 1 or monotonic_points.shape[0] == poses.shape[0]
)
# Either there is no outline, a fixed outline, or as many outlines as monotonic points
assert (
outlines.shape[0] == 1
or monotonic_points.shape[0] == outlines.shape[0]
outlines is None
or outlines.shape[0] == 1
or monotonic_points.shape[0] == outlines.shape[0]
)
else:
# Fixed in Place and fixed outline
assert poses.shape[0] == 1
assert outlines.shape[0] == 1
assert outlines is None or outlines.shape[0] == 1
# calendar points
if "calendar points" in time:
calendar_points = time["calendar points"]
if "d_calendar points" in time:
calendar_points = time["d_calendar points"]
assert calendar_points.ndim == 1
assert calendar_points.shape[0] == monotonic_points.shape[0]
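With the prefix check now asserted (see the fix above), a key with an unknown prefix fails validation. A small sketch:

bad_track = {
    "a_version": [1, 0],
    "a_world_size": [100, 100],
    "x_entities": {},  # "x" is not one of a/d/g
}
try:
    util.validate_track(bad_track)
except AssertionError:
    print("rejected: unknown prefix")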
......
from robofish.io.track import Track


def test_init():
    track = Track([100, 100])
    # track.add_entity_poses("fish1", [[[1, 1, 0, 0], [1, 1, 0, 0]]])
@@ -55,35 +55,33 @@ def test_write_hdf5_from_track():
"""
test_dict = {
"version": [1, 0],
"world_size": [100, 100],
"entities": {
"fish_1": {
"type": "fish",
"poses": [[0, 0, 1, 0], [0.1, 0, 1, 0], [0.2, 0, 1, 0]],
"outlines": [[[1, 1]]],
"time": {
"monotonic points": [0, 1, 2],
"calendar points": [
"a_version": [1, 0],
"a_world_size": [100, 100],
"g_entities": {
"g_fish_1": {
"a_type": "fish",
"d_poses": [[0, 0, 1, 0], [0.1, 0, 1, 0], [0.2, 0, 1, 0]],
"e_outlines": [[[1, 1]]],
"g_time": {
"d_monotonic points": [0, 1, 2],
"d_calendar points": [
"2020-11-18T13:21:34.117+01:00",
"2020-11-18T13:21:34.117+01:00",
"2020-11-18T13:21:34.117+01:00",
],
},
},
"obstacle1": {
"type": "obstacle",
"poses": [[0, 0, 1, 0]],
"outlines": [[[1, 1]]],
"time": {},
"g_obstacle1": {
"a_type": "obstacle",
"d_poses": [[0, 0, 1, 0]],
"d_outlines": [[[1, 1]]],
"g_time": {},
},
},
}
testfile_path = full_path("/../../resources/created_by_test.hdf5")
h5file = util.write_hdf5_from_track(testfile_path, test_dict)
assert type(h5file) == h5py._hl.files.File
util.write_hdf5_from_track(testfile_path, test_dict)
def test_read_track_from_hdf5():
......