Skip to content
Snippets Groups Projects
Commit 451fb38e authored by Andi Gerken's avatar Andi Gerken
Browse files

Merge branch 'dev_moritz_property' into 'master'

Add properties

Closes #2

See merge request !8
parents f2956f7b dc02a2cd
No related branches found
No related tags found
1 merge request!8Add properties
Pipeline #35613 passed with warnings
%% Cell type:code id: tags:
``` python
#! /usr/bin/env python3
import robofish.io
import numpy as np
from pathlib import Path
import os
# Helper function to enable relative paths from this file
def full_path(path):
return (Path(os.path.abspath("__file__")).parent / path).resolve()
if __name__ == "__main__":
# Create a new io file object with a 100x100cm world
sf = robofish.io.File(world_size=[100, 100])
# create a simple obstacle, fixed in place, fixed outline
obstacle_pose = [[50, 50, 0, 0]]
obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]]
obstacle_name = sf.create_entity(
"obstacle", poses=obstacle_pose, outlines=obstacle_outline
)
# create a robofish with 100 timesteps and 40ms between the timesteps. If we would not give a name, the name would be generated to be robot_1.
robofish_timesteps = 4
robofish_poses = np.zeros((robofish_timesteps, 4))
sf.create_entity("robot", robofish_poses, name="robot", monotonic_step=40)
# create multiple fishes with timestamps. Since we don't specify names, but only the type "fish" the fishes will be named ["fish_1", "fish_2", "fish_3"]
agents = 3
timesteps = 5
timestamps = np.linspace(0, timesteps + 1, timesteps)
agent_poses = np.random.random((agents, timesteps, 4))
fish_names = sf.create_multiple_entities(
"fish", agent_poses, monotonic_points=timestamps
)
# This would throw an exception if the file was invalid
sf.validate()
# Save file validates aswell
example_file = full_path("example.hdf5")
sf.save(example_file)
# Closing and opening files (just for demonstration)
sf.close()
sf = robofish.io.File(path=example_file)
print("\nEntity Names")
print(sf.get_entity_names())
print(sf.entity_names)
# Get an array with all poses. As the length of poses varies per agent, it
# is filled up with nans. The result is not interpolated and the time scales
# per agent are different. It is planned to create a warning in the case of
# different time scales and have another function, which generates an
# interpolated array.
print("\nAll poses")
print(sf.get_poses_array())
print(sf.select_entity_poses())
print("\nFish poses")
print(sf.get_poses_array(fish_names))
print(sf.select_entity_poses(lambda e: e.category == "fish"))
print("\nFile structure")
print(sf)
```
%% Output
Entity Names
['fish_1', 'fish_2', 'fish_3', 'obstacle_1', 'robot']
All poses
[[[8.86584103e-01 2.35670820e-01 5.41754842e-01 4.49850202e-01]
[8.15511882e-01 4.78223324e-01 6.29803419e-01 1.12592392e-01]
[1.53732300e-01 7.24954247e-01 9.38574493e-01 4.65665817e-01]
[9.10354614e-01 4.47880208e-01 3.81429136e-01 9.67544317e-01]
[6.07822955e-01 5.20158827e-01 8.17965686e-01 8.42760384e-01]]
[[2.29353935e-01 8.80753636e-01 7.94585168e-01 2.22074524e-01]
[6.13970399e-01 1.33511815e-02 2.89155185e-01 2.65219092e-01]
[6.62197351e-01 6.47982001e-01 9.46004018e-02 6.59599364e-01]
[4.86104101e-01 4.23153102e-01 1.39821902e-01 3.11809748e-01]
[8.03322852e-01 9.52799857e-01 3.89638603e-01 6.43237352e-01]]
[[9.70978260e-01 6.75936878e-01 6.23196602e-01 8.42264950e-01]
[4.07079160e-01 8.46290290e-01 5.64092159e-01 3.56871307e-01]
[4.84096229e-01 8.60232174e-01 1.39015794e-01 7.82253265e-01]
[1.24170482e-01 2.21511930e-01 8.88282284e-02 4.53450561e-01]
[1.28404438e-01 2.87771430e-02 4.57022637e-01 9.80571806e-01]]
[[5.00000000e+01 5.00000000e+01 0.00000000e+00 0.00000000e+00]
[ nan nan nan nan]
[ nan nan nan nan]
[ nan nan nan nan]
[ nan nan nan nan]]
[[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
[ nan nan nan nan]]]
Fish poses
[[[0.8865841 0.23567082 0.54175484 0.4498502 ]
[0.81551188 0.47822332 0.62980342 0.11259239]
[0.1537323 0.72495425 0.93857449 0.46566582]
[0.91035461 0.44788021 0.38142914 0.96754432]
[0.60782295 0.52015883 0.81796569 0.84276038]]
[[0.22935393 0.88075364 0.79458517 0.22207452]
[0.6139704 0.01335118 0.28915519 0.26521909]
[0.66219735 0.647982 0.0946004 0.65959936]
[0.4861041 0.4231531 0.1398219 0.31180975]
[0.80332285 0.95279986 0.3896386 0.64323735]]
[[0.97097826 0.67593688 0.6231966 0.84226495]
[0.40707916 0.84629029 0.56409216 0.35687131]
[0.48409623 0.86023217 0.13901579 0.78225327]
[0.12417048 0.22151193 0.08882823 0.45345056]
[0.12840444 0.02877714 0.45702264 0.98057181]]]
File structure
version: [1 0]
world size: [100. 100.]
| entities
|---| fish_1
|---|--- type: fish
|---|--- poses: Shape (5, 4)
|---|---| time
|---|---|--- monotonic points: Shape (5,)
|---| fish_2
|---|--- type: fish
|---|--- poses: Shape (5, 4)
|---|---| time
|---|---|--- monotonic points: Shape (5,)
|---| fish_3
|---|--- type: fish
|---|--- poses: Shape (5, 4)
|---|---| time
|---|---|--- monotonic points: Shape (5,)
|---| obstacle_1
|---|--- type: obstacle
|---|--- outlines: Shape (1, 4, 2)
|---|--- poses: Shape (1, 4)
|---|---| time
|---| robot
|---|--- type: robot
|---|--- poses: Shape (4, 4)
|---|---| time
|---|---|--- monotonic step: 40
......
......@@ -46,14 +46,14 @@ if __name__ == "__main__":
sf = robofish.io.File(path=example_file)
print("\nEntity Names")
print(sf.get_entity_names())
print(sf.entity_names)
# Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.
print("\nAll poses")
print(sf.get_poses())
print(sf.entity_poses)
print("\nFish poses")
print(sf.get_poses(category="fish"))
print(sf.select_entity_poses(lambda e: e.category == "fish"))
print("\nFile structure")
print(sf)
......@@ -54,8 +54,17 @@ class Entity(h5py.Group):
ori_vec[:, 1] = np.sin(ori_rad[:, 0])
return ori_vec
def getName(self):
return self.name.split("/")[-1]
@property
def group_name(self):
return super().name
@property
def name(self):
return self.group_name.split("/")[-1]
@property
def category(self):
return self.attrs["category"]
def create_outlines(self, outlines: Iterable, sampling=None):
outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32)
......@@ -98,12 +107,21 @@ class Entity(h5py.Group):
positions.attrs["sampling"] = sampling
orientations.attrs["sampling"] = sampling
def get_poses(self):
poses = np.concatenate([self["positions"], self["orientations"]], axis=1)
return poses
@property
def positions(self):
return self["positions"]
@property
def orientations(self):
return self["orientations"]
@property
def poses(self):
return np.concatenate([self.positions, self.orientations], axis=1)
def get_poses_rad(self):
poses = self.get_poses()
@property
def poses_rad(self):
poses = self.poses
# calculate the angles from the orientation vectors, write them to the third row and delete the fourth row
poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2])
poses = poses[:, :3]
......
......@@ -201,7 +201,13 @@ class File(h5py.File):
self["samplings"].attrs["default"] = self.default_sampling
return name
def get_frequency(self):
@property
def world_size(self):
return self.attrs["world_size_cm"]
@property
def frequency(self):
# NOTE: Only works if default sampling availabe and specified with frequency_hz.
default_sampling = self["samplings"].attrs["default"]
return self["samplings"][default_sampling].attrs["frequency_hz"]
......@@ -290,13 +296,8 @@ class File(h5py.File):
)
return returned_names
def get_entities(self):
return {
e_name: robofish.io.Entity.from_h5py_group(e_group)
for e_name, e_group in self["entities"].items()
}
def get_entity_names(self) -> Iterable[str]:
@property
def entity_names(self) -> Iterable[str]:
""" Getter for the names of all entities
Returns:
......@@ -304,8 +305,16 @@ class File(h5py.File):
"""
return sorted(self["entities"].keys())
def get_poses(self, names: Iterable = None, category: str = None) -> Iterable:
""" Get an array of the poses of entities
@property
def entities(self):
return [robofish.io.Entity.from_h5py_group(self["entities"][name]) for name in self.entity_names]
@property
def entity_poses(self):
return self.select_entity_poses(None)
def select_entity_poses(self, predicate = None) -> Iterable:
""" Select an array of the poses of entities
If no name or category is specified, all entities will be selected.
......@@ -316,57 +325,29 @@ class File(h5py.File):
An three dimensional array of all poses with the shape (entity, time, 4)
"""
if names is not None and category is not None:
logging.error("Specify either names or a category, not both.")
raise Exception
# collect the names of all entities with the correct category
if category is not None:
names = [
e_name
for e_name, e_data in self["entities"].items()
if e_data.attrs["category"] == category
]
entities = self.get_entities()
entities = self.entities
if predicate is not None:
entities = [e for e in entities if predicate(e)]
# If no names or category are given, select all
if names is None:
names = sorted(entities.keys())
# Entity objects given as names
if all([type(name) == robofish.io.Entity for name in names]):
names = [entity.getName() for entity in names]
if not all([type(name) == str for name in names]):
raise Exception(
"Given names were not strings. Instead names were %s" % names
)
max_timesteps = (
0
if len(names) == 0
else max([entities[e_name]["positions"].shape[0] for e_name in names])
)
max_timesteps = max([0] + [e.positions.shape[0] for e in entities])
# Initialize poses output array
poses_output = np.empty((len(names), max_timesteps, 4))
poses_output = np.empty((len(entities), max_timesteps, 4))
poses_output[:] = np.nan
# Fill poses output array
i = 0
custom_sampling = None
for name in names:
entity = entities[name]
for entity in entities:
if "sampling" in entity["positions"].attrs:
if custom_sampling is None:
custom_sampling = entity["positions"].attrs["sampling"]
elif custom_sampling != entity["positions"].attrs["sampling"]:
raise Exception(
"Multiple samplings found, which can not be given back by the get_poses function collectively."
"Multiple samplings found, preventing return of a single array."
)
poses = entity.get_poses()
poses = entity.poses
poses_output[i][: poses.shape[0]] = poses
i += 1
return poses_output
......
......@@ -150,7 +150,9 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
# validate entities
assert_validate("entities" in iofile, "entities not found")
for e_name, entity in iofile.get_entities().items():
for entity in iofile.entities:
e_name = entity.name
assert_validate(
type(entity) == Entity,
"Entity group was not a robofish.io.Entity object",
......
......@@ -12,7 +12,7 @@ def test_entity_object():
sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25)
f = sf.create_entity("fish", positions=[[10, 10]])
assert type(f) == robofish.io.Entity
assert f.getName() == "fish_1"
assert f.name == "fish_1"
assert f.attrs["category"] == "fish"
print(dir(f))
print(f["positions"])
......@@ -26,7 +26,7 @@ def test_entity_object():
f2 = sf.create_entity("fish", poses=poses_rad)
assert type(f2["positions"]) == h5py.Dataset
assert type(f2["orientations"]) == h5py.Dataset
poses_rad_retrieved = f2.get_poses_rad()
poses_rad_retrieved = f2.poses_rad
# Check if retrieved rad poses is close to the original poses.
# Internally always ori_x and ori_y are used. When retrieved, the range is from -pi to pi, so for some of our original data 2 pi has to be substracted.
......
......@@ -69,7 +69,7 @@ def test_multiple_entities():
sf = robofish.io.File(world_size_cm=[100, 100], monotonic_time_points_us=m_points)
returned_entities = sf.create_multiple_entities("fish", poses)
returned_names = [entity.getName() for entity in returned_entities]
returned_names = [entity.name for entity in returned_entities]
expected_names = ["fish_1", "fish_2", "fish_3"]
print(returned_names)
......@@ -79,24 +79,24 @@ def test_multiple_entities():
sf.validate()
# The returned poses should be equal to the inserted poses
returned_poses = sf.get_poses()
returned_poses = sf.entity_poses
print(returned_poses)
assert (returned_poses == poses).all()
# Just get the array for some names
returned_poses = sf.get_poses(["fish_1", "fish_2"])
returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"])
assert (returned_poses == poses[:2]).all()
# Falsely specify names and category
with pytest.raises(Exception):
sf.get_poses(names=["fish_1"], category="fish")
# Filter on both category and name
returned_poses = sf.select_entity_poses(lambda e: e.category == "fish" and e.name == "fish_1")
assert (returned_poses == poses[:1]).all()
# Insert some random obstacles
returned_names = sf.create_multiple_entities(
"obstacle", poses=np.random.random((agents, timesteps, 4))
)
# Obstacles should not be returned when only fish are selected
returned_poses = sf.get_poses(category="fish")
returned_poses = sf.select_entity_poses(lambda e: e.category == "fish")
assert (returned_poses == poses).all()
# for each of the entities
......@@ -124,12 +124,12 @@ def test_multiple_entities():
print(returned_names)
print(sf)
# pass an poses array in separate parts (positions, orientations) and retreive it with get_poses.
# pass an poses array in separate parts (positions, orientations) and retrieve it with poses.
poses_arr = np.random.random((100, 4))
position_orientation_fish = sf.create_entity(
"fish", positions=poses_arr[:, :2], orientations=poses_arr[:, 2:]
)
assert np.isclose(poses_arr, position_orientation_fish.get_poses()).all()
assert np.isclose(poses_arr, position_orientation_fish.poses).all()
sf.validate()
return sf
......@@ -142,7 +142,7 @@ def test_load_validate():
def test_get_entity_names():
sf = robofish.io.File(path=valid_file_path)
names = sf.get_entity_names()
names = sf.entity_names
assert len(names) == 9
assert names[0] == "fish_1"
assert names[1] == "fish_2"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment