Skip to content
Snippets Groups Projects
Commit e9ea3fe2 authored by marc131183's avatar marc131183
Browse files

merging

Merge branch 'master' into develop_marc
parents 2525ed7e c838f70d
No related branches found
No related tags found
1 merge request!11Develop marc
Showing
with 641 additions and 170 deletions
......@@ -17,3 +17,4 @@ env
*.mp4
feature_requests.md
output_graph.png
......@@ -4,3 +4,6 @@ pdoc --html robofish.io robofish.evaluate --html-dir docs --force
## Code coverage:
pytest --cov=src --cov-report=html
## Flake
flake8 --ignore E203 --max-line-length 88
......@@ -10,11 +10,12 @@ def create_example_file(path):
# Create a new robot entity. Positions and orientations are passed
# separately in this example. Since the orientations have two columns,
# unit vectors are assumed (orientation_x, orientation_y)
circle_rad = np.linspace(0, 2 * np.pi, num=100)
f.create_entity(
category="robot",
name="robot",
positions=np.zeros((100, 2)),
orientations=np.ones((100, 2)) * [0, 1],
positions=np.stack((np.cos(circle_rad), np.sin(circle_rad))).T * 40,
orientations=np.stack((-np.sin(circle_rad), np.cos(circle_rad))).T,
)
# Create a new fish entity.
......
[build-system]
# Minimum requirements for the build system to execute.
requires = ["setuptools", "wheel", "pytest", "pandas", "deprecation"] # PEP 508 specifications.
requires = ["setuptools", "wheel"] # PEP 508 specifications.
[flake8]
max-line-length = 88
extend-ignore = "E203"
......@@ -13,6 +13,21 @@ import robofish.evaluate
import argparse
def function_dict():
base = robofish.evaluate.evaluate
return {
"speed": base.evaluate_speed,
"turn": base.evaluate_turn,
"orientation": base.evaluate_orientation,
"relative_orientation": base.evaluate_relativeOrientation,
"distance_to_wall": base.evaluate_distanceToWall,
"tank_positions": base.evaluate_tankpositions,
"trajectories": base.evaluate_trajectories,
"evaluate_positionVec": base.evaluate_positionVec,
"follow_iid": base.evaluate_follow_iid,
}
def evaluate(args=None):
"""This function can be called from the commandline to evaluate files.
......@@ -24,13 +39,7 @@ def evaluate(args=None):
(robofish-io-evaluate --help for more info)
"""
function_dict = {
"speed": robofish.evaluate.evaluate.evaluate_speed,
"turn": robofish.evaluate.evaluate.evaluate_turn,
"tank_positions": robofish.evaluate.evaluate.evaluate_tankpositions,
"trajectories": robofish.evaluate.evaluate.evaluate_trajectories,
"follow_iid": robofish.evaluate.evaluate.evaluate_follow_iid,
}
fdict = function_dict()
parser = argparse.ArgumentParser(
description="This function can be called from the commandline to evaluate files.\
......@@ -42,7 +51,7 @@ def evaluate(args=None):
parser.add_argument(
"analysis_type",
type=str,
choices=function_dict.keys(),
choices=fdict.keys(),
help="The type of analysis.\
speed - A histogram of speeds\
turn - A histogram of angular velocities\
......@@ -76,7 +85,7 @@ def evaluate(args=None):
if args is None:
args = parser.parse_args()
if args.analysis_type in function_dict:
function_dict[args.analysis_type](args.paths, args.names, args.save_path)
if args.analysis_type in fdict:
fdict[args.analysis_type](args.paths, args.names, args.save_path)
else:
print(f"Evaluation function not found {args.analysis_type}")
This diff is collapsed.
......@@ -137,6 +137,18 @@ class Entity(h5py.Group):
return np.tile([1, 0], (self.positions.shape[0], 1))
return self["orientations"]
@property
def orientations_calculated(self):
diff = np.diff(self.positions, axis=0)
angles = np.arctan2(diff[:, 1], diff[:, 0])
return angles[:, np.newaxis]
@property
def poses_calc_ori_rad(self):
return np.concatenate(
[self.positions[:-1], self.orientations_calculated], axis=1
)
@property
def poses(self):
return np.concatenate([self.positions, self.orientations], axis=1)
......@@ -145,6 +157,26 @@ class Entity(h5py.Group):
def poses_rad(self):
poses = self.poses
# calculate the angles from the orientation vectors, write them to the third row and delete the fourth row
poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2])
poses = poses[:, :3]
return poses
ori_rad = np.arctan2(poses[:, 3], poses[:, 2])
return np.concatenate([poses[:, :2], ori_rad[:, np.newaxis]], axis=1)
@property
def speed_turn(self):
"""Get the speed, turn and from the positions.
The vectors pointing from each position to the next are computed.
The output of the function describe these vectors.
Returns:
An array with shape (number_of_positions -1, 3).
The first column is the length of the vectors.
The second column is the turning angle, required to get from one vector to the next.
We assume, that the entity is oriented "correctly" in the first pose. So the first turn angle is 0.
"""
diff = np.diff(self.positions, axis=0)
speed = np.linalg.norm(diff, axis=1)
angles = np.arctan2(diff[:, 1], diff[:, 0])
turn = np.zeros_like(angles)
turn[0] = 0
turn[1:] = utils.limit_angle_range(np.diff(angles))
return np.stack([speed, turn], axis=-1)
......@@ -15,6 +15,7 @@
# -----------------------------------------------------------
import robofish.io
from robofish.io.entity import Entity
import h5py
import numpy as np
......@@ -27,6 +28,7 @@ import datetime
import tempfile
import uuid
import deprecation
import types
default_format_version = np.array([1, 0], dtype=np.int32)
......@@ -131,6 +133,7 @@ class File(h5py.File):
def __exit__(self, type, value, traceback):
# Check if the context was left under normal circumstances
if (type, value, traceback) == (None, None, None):
if self.mode != "r": # No need to validate read only files (performance).
self.validate()
self.close()
......@@ -201,6 +204,9 @@ class File(h5py.File):
calendar_time_points = [
format_calendar_time_point(p) for p in calendar_time_points
]
for c in calendar_time_points:
print(type(c))
sampling.create_dataset(
"calendar_time_points",
data=calendar_time_points,
......@@ -223,9 +229,33 @@ class File(h5py.File):
@property
def frequency(self):
# NOTE: Only works if default sampling availabe and specified with frequency_hz.
default_sampling = self["samplings"].attrs["default"]
return self["samplings"][default_sampling].attrs["frequency_hz"]
common_sampling = self.common_sampling()
assert common_sampling is not None, "The sampling differs between entities."
assert (
"frequency_hz" in common_sampling.attrs
), "The common sampling has no frequency_hz"
return common_sampling.attrs["frequency_hz"]
def common_sampling(
self, entities: Iterable["robofish.io.Entity"] = None
) -> h5py.Group:
"""Check if all entities have the same sampling.
Args:
entities: optional array of entities. If None is given, all entities are checked.
Returns:
The h5py group of the common sampling. If there is no common sampling, None will be returned.
"""
custom_sampling = None
for entity in self.entities:
if "sampling" in entity["positions"].attrs:
this_sampling = entity["positions"].attrs["sampling"]
if custom_sampling is None:
custom_sampling = this_sampling
elif custom_sampling != this_sampling:
return None
sampling = self.default_sampling if custom_sampling is None else custom_sampling
return self["samplings"][sampling]
def create_entity(
self,
......@@ -329,52 +359,67 @@ class File(h5py.File):
@property
def entity_poses(self):
return self.select_entity_poses(None)
return self.select_entity_property(None)
@property
def entity_poses_rad(self):
return self.select_entity_poses(None, rad=True)
return self.select_entity_property(None, entity_property=Entity.poses_rad)
def select_entity_poses(self, predicate=None, rad=False) -> Iterable:
""" TODO: Rework
Select an array of the poses of entities
@property
def entity_poses_calc_ori_rad(self):
return self.select_entity_property(
None, entity_property=Entity.poses_calc_ori_rad
)
@property
def entity_speeds_turns(self):
return self.select_entity_property(None, entity_property=Entity.speed_turn)
def select_entity_poses(self, *args, ori_rad=False, **kwargs):
entity_property = Entity.poses_rad if ori_rad else Entity.poses
return self.select_entity_property(
*args, entity_property=entity_property, **kwargs
)
If no name or category is specified, all entities will be selected.
def select_entity_property(
self,
predicate: types.LambdaType = None,
entity_property: property = Entity.poses,
) -> Iterable:
"""Get a property of selected entities.
Entities can be selected, using a lambda function.
The property of the entities can be selected.
Args:
names: optional array of the names of selected entities
category: optional selected category
predicate: a lambda function, selecting entities
(example: lambda e: e.category == "fish")
entity_property: a property of the Entity class (example: Entity.poses_rad)
Returns:
An three dimensional array of all poses with the shape (entity, time, 4)
An three dimensional array of all properties of all entities with the shape (entity, time, property_length).
If an entity has a shorter length of the property, the output will be filled with nans.
"""
entities = self.entities
if predicate is not None:
entities = [e for e in entities if predicate(e)]
max_timesteps = max([0] + [e.positions.shape[0] for e in entities])
assert self.common_sampling(entities) is not None
# Initialize poses output array
pose_len = 3 if rad else 4
poses_output = np.empty((len(entities), max_timesteps, pose_len))
poses_output[:] = np.nan
properties = [entity_property.__get__(entity) for entity in entities]
# Fill poses output array
i = 0
custom_sampling = None
for entity in entities:
max_timesteps = max([0] + [p.shape[0] for p in properties])
if "sampling" in entity["positions"].attrs:
if custom_sampling is None:
custom_sampling = entity["positions"].attrs["sampling"]
elif custom_sampling != entity["positions"].attrs["sampling"]:
raise Exception(
"Multiple samplings found, preventing return of a single array."
property_array = np.empty(
(len(entities), max_timesteps, properties[0].shape[1])
)
poses = entity.poses_rad if rad else entity.poses
poses_output[i][: poses.shape[0]] = poses
i += 1
return poses_output
property_array[:] = np.nan
# Fill output array
for i, entity in enumerate(entities):
property_array[i][: properties[i].shape[0]] = properties[i]
return property_array
@deprecation.deprecated(
deprecated_in="1.1.2",
......@@ -393,6 +438,24 @@ class File(h5py.File):
predicate = None
return self.select_entity_poses(predicate)
def entity_turn_speed(self, predicate=None):
"""Get an array of turns and speeds of the entities."""
entities = self.entities
if predicate is not None:
entities = [e for e in entities if predicate(e)]
assert self.common_sampling(entities) is not None
# Initialize poses output array
max_timesteps = max([0] + [e.positions.shape[0] for e in entities])
turn_speed = np.empty((len(entities), max_timesteps, 3))
turn_speed[:] = np.nan
for i, entity in enumerate(entities):
poses = entity.poses_rad if rad else entity.poses
poses_output[i][: poses.shape[0]] = poses
def validate(self, strict_validate: bool = True) -> (bool, str):
"""Validate the file to the specification.
......
......@@ -24,7 +24,6 @@ def read_multiple_files(
paths: Union[Path, str, Iterable[Path], Iterable[str]],
strict_validate: bool = False,
max_files: int = None,
shuffle: bool = False,
) -> dict:
"""Load hdf5 files from a given path.
......@@ -35,7 +34,7 @@ def read_multiple_files(
Args:
path: The path to a hdf5 file or folder.
strict_validate: Choice between error and warning in case of invalidity
max: Maximum number of files to be read
max_files: Maximum number of files to be read
Returns:
dict: A dictionary where the keys are filenames and the opened robofish.io.File objects
"""
......@@ -49,8 +48,6 @@ def read_multiple_files(
paths = [Path(p) for p in paths]
sf_dict = {}
if shuffle:
random.shuffle(paths)
for path in paths:
if path.is_dir():
logging.info("found dir %s" % path)
......@@ -72,3 +69,71 @@ def read_multiple_files(
logging.info("found file %s" % path)
sf_dict[path] = robofish.io.File(path=path, strict_validate=strict_validate)
return sf_dict
def read_property_from_multiple_files(
paths: Union[Path, str, Iterable[Path], Iterable[str]],
entity_property: property = None,
*,
strict_validate: bool = False,
max_files: int = None,
shuffle: bool = False,
predicate: callable = None,
):
"""Load hdf5 files from a given path and return the property of the entities.
The function can be given the path to a single single hdf5 file, to a folder,
containing hdf5 files, or an array of multiple files or folders.
Args:
path: The path to a hdf5 file or folder.
entity_property: A property of robofish.io.Entity default is Entity.poses_rad
strict_validate: Choice between error and warning in case of invalidity
max_files: Maximum number of files to be read
shuffle: Shuffle the order of files
predicate:
Returns:
An array of all entity properties arrays
"""
assert (
entity_property is not None
), "Please select an entity property e.g. 'Entity.poses_rad'"
logging.info(f"Reading files from path {paths}")
list_types = (list, np.ndarray, pandas.core.series.Series)
if not isinstance(paths, list_types):
paths = [paths]
paths = [Path(p) for p in paths]
poses_array = []
for path in paths:
if path.is_dir():
logging.info("found dir %s" % path)
# Find all hdf5 files in folder
files = []
for ext in ("hdf", "hdf5", "h5", "he5"):
files += list(path.rglob(f"*.{ext}"))
files = random.shuffle(files) if shuffle else sorted(files)
logging.info("Reading files")
for file in files:
if max_files is not None and len(poses_array) >= max_files:
break
if not file.is_dir():
with robofish.io.File(
path=file, strict_validate=strict_validate
) as f:
p = f.select_entity_property(predicate, entity_property)
poses_array.append(p)
elif path is not None and path.exists():
logging.info("found file %s" % path)
with robofish.io.File(path=path, strict_validate=strict_validate) as f:
p = f.select_entity_property(predicate, entity_property)
poses_array.append(p)
return poses_array
\ No newline at end of file
import robofish.io
import numpy as np
from typing import Union, Iterable
from pathlib import Path
import os
......@@ -13,3 +14,25 @@ def np_array(*arrays):
def full_path(current_file, path):
return (Path(current_file).parent / path).resolve()
def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)):
"""Limit the range of an angle or array of angles between min and max
Any given angle in rad will be moved to the given range.
e.g. with min = -pi and max pi, 4.5*pi will be moved to be 0.5*pi.
Args:
angle: An angle or an array of angles (1D)
_range: optional tuple of range (min, max)
Returns:
limited angle(s) in the same form as the input (float or ndarray)
"""
assert np.isclose(_range[1] - _range[0], 2 * np.pi)
def limit_one(value):
return (value - _range[0]) % (2 * np.pi) + _range[0]
if isinstance(angle, Iterable):
return np.array([limit_one(v) for v in angle])
else:
return limit_one(angle)
\ No newline at end of file
......@@ -6,7 +6,9 @@ import numpy as np
import logging
def assert_validate(statement: bool, message: str, location: str = None) -> None:
def assert_validate(
statement: bool, message: str, location: str = None, strict_validate=True
) -> None:
""" Assert the statement and attach the entity name to the error message.
Args:
......@@ -18,9 +20,12 @@ def assert_validate(statement: bool, message: str, location: str = None) -> None
"""
if not statement:
if location:
raise AssertionError("%s in %s" % (message, location))
else:
message = "%s in %s" % (message, location)
if strict_validate:
raise AssertionError(message)
else:
logging.warning(message)
def assert_validate_type(
......@@ -230,6 +235,7 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
e_name,
)
if strict_validate:
validate_orientations_length(orientations, e_name)
# outlines
......@@ -329,11 +335,18 @@ def validate_positions_range(world_size, positions, e_name):
def validate_orientations_length(orientations, e_name):
ori_lengths = np.linalg.norm(orientations, axis=1)
# import matplotlib.pyplot as plt
#
# plt.plot(ori_lengths)
# plt.show()
# Check if all orientation lengths are all 1. Different lengths cause warnings.
assert_validate(
np.isclose(ori_lengths, 1).all(),
"The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1"
"The orientation vectors were not unit vectors. Their length was in the range [%.4f, %.4f] when it should be 1"
% (min(ori_lengths), max(ori_lengths)),
e_name,
strict_validate=False,
)
......
No preview for this file type
......@@ -22,7 +22,6 @@ def test_app_validate():
self.names = None
self.save_path = graphics_out
# TODO: Get rid of deprecation
with pytest.warns(DeprecationWarning):
app.evaluate(DummyArgs("speed"))
for mode in app.function_dict().keys():
app.evaluate(DummyArgs(mode))
graphics_out.unlink()
import pytest
import robofish.evaluate
from robofish.io import utils
import numpy as np
def test_get_all_poses_from_paths():
valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5")
poses, frequency = robofish.evaluate.get_all_poses_from_paths([valid_file_path])
# (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses)
assert np.array(poses).shape == (1, 1, 2, 100, 4)
assert frequency == 25
......@@ -29,3 +29,37 @@ def test_entity_object():
assert np.isclose(poses_rad[i, 2], poses_rad_retrieved[i, 2]) or np.isclose(
poses_rad[i, 2] - 2 * np.pi, poses_rad_retrieved[i, 2]
)
def test_entity_turn_speed():
f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25)
circle_rad = np.linspace(0, 2 * np.pi, num=100)
circle_size = 40
positions = np.stack(
[np.cos(circle_rad) * circle_size, np.sin(circle_rad) * circle_size], axis=-1
)
e = f.create_entity("fish", positions=positions)
speed_turn = e.speed_turn
assert speed_turn.shape == (99, 2)
# No turn in the first timestep, since initialization turns it the right way
assert speed_turn[0, 1] == 0
# Turns and speeds shoud afterwards be all the same afterwards, since the fish swims with constant velocity and angular velocity.
assert (np.std(speed_turn[1:], axis=0) < 0.0001).all()
# Use turn_speed to generate positions
gen_positions = np.zeros((positions.shape[0], 3))
gen_positions[0] = e.poses_calc_ori_rad[0]
for i, (speed, turn) in enumerate(speed_turn):
new_angle = gen_positions[i, 2] + turn
gen_positions[i + 1] = [
gen_positions[i, 0] + np.cos(new_angle) * speed,
gen_positions[i, 1] + np.sin(new_angle) * speed,
new_angle,
]
# The resulting positions should almost be equal to the the given positions
print(gen_positions[:, :2] - positions)
assert np.isclose(positions, gen_positions[:, :2], atol=1.0e-5).all()
import robofish.io
from robofish.io import utils
from pathlib import Path
from testbook import testbook
import sys
......@@ -29,6 +29,8 @@ def test_example_basic():
# This test can be executed manually. The CI/CD System has issues with testbook.
def manual_test_example_basic_ipynb():
from testbook import testbook
# Executing the notebook should not lead to an exception
with testbook(str(ipynb_path), execute=True) as tb:
pass
......
......@@ -89,11 +89,11 @@ def test_multiple_entities():
assert (returned_poses == poses).all()
# Just get the array for some names
returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"])
returned_poses = sf.select_entity_property(lambda e: e.name in ["fish_1", "fish_2"])
assert (returned_poses == poses[:2]).all()
# Filter on both category and name
returned_poses = sf.select_entity_poses(
returned_poses = sf.select_entity_property(
lambda e: e.category == "fish" and e.name == "fish_1"
)
assert (returned_poses == poses[:1]).all()
......@@ -103,7 +103,7 @@ def test_multiple_entities():
obs_poses = np.random.random((agents, 1, 3))
returned_names = sf.create_multiple_entities("obstacle", poses=obs_poses)
# Obstacles should not be returned when only fish are selected
returned_poses = sf.select_entity_poses(lambda e: e.category == "fish")
returned_poses = sf.select_entity_property(lambda e: e.category == "fish")
assert (returned_poses == poses).all()
# for each of the entities
......@@ -139,6 +139,15 @@ def test_multiple_entities():
return sf
def test_speeds_turns_angles():
with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
poses = np.zeros((10, 100, 3))
f.create_multiple_entities("fish", poses=poses)
# Stationary fish has no speed or turn
assert (f.entity_speeds_turns == 0).all()
def test_broken_sampling(caplog):
sf = robofish.io.File(world_size_cm=[10, 10])
caplog.set_level(logging.ERROR)
......@@ -178,6 +187,9 @@ def test_entity_positions_no_orientation():
assert f.entity_poses.shape == (1, 100, 4)
assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all()
# Calculate the orientation
assert f.entity_poses_calc_ori_rad.shape == (1, 99, 3)
def test_load_validate():
sf = robofish.io.File(path=valid_file_path)
......@@ -187,8 +199,8 @@ def test_load_validate():
def test_get_entity_names():
sf = robofish.io.File(path=valid_file_path)
names = sf.entity_names
assert len(names) == 1
assert names[0] == "fish_1"
assert len(names) == 2
assert names == ["fish_1", "robot"]
def test_File_without_path_or_worldsize():
......
......@@ -2,6 +2,7 @@ import robofish.io
from robofish.io import utils
import pytest
from pathlib import Path
import numpy as np
def test_now_iso8061():
......@@ -38,3 +39,18 @@ def test_read_multiple_folder():
for p, f in sf.items():
print(p)
assert type(f) == robofish.io.File
path = utils.full_path(__file__, "../../resources/valid.hdf5")
# TODO read from folder of valid files
@pytest.mark.parametrize("_path", [path, str(path)])
def test_read_poses_rad_from_multiple_folder(_path):
poses = robofish.io.read_property_from_multiple_files(
[_path, _path], robofish.io.entity.Entity.poses_rad
)
# Should find the 3 presaved hdf5 files
assert len(poses) == 2
for p in poses:
print(p)
assert type(p) == np.ndarray
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment