Skip to content
Snippets Groups Projects
Commit 8cf4c862 authored by Andi Gerken's avatar Andi Gerken
Browse files

Switched off validation when opening a file.

This increases performance drastically.
parent eb95717d
No related branches found
No related tags found
No related merge requests found
......@@ -10,6 +10,7 @@ Functions available to be used in the commandline to evaluate robofish.io files.
# Last doku update Feb 2021
import robofish.evaluate
import argparse
from pathlib import Path
import matplotlib.pyplot as plt
......
......@@ -24,77 +24,6 @@ import inspect
import random
def get_all_files_from_paths(paths: Iterable[Union[str, Path]]):
# Find all files with correct ending
files = []
for path in [Path(p) for p in paths]:
if path.is_dir():
files_path = []
for ext in ("hdf", "hdf5", "h5", "he5"):
files_path += list(path.rglob(f"*.{ext}"))
files.append(files_path)
else:
files.append([path])
return files
def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None):
"""Read all poses from given paths.
The function shall be used by the evaluation functions.
Args:
paths: An array of strings, with files or folders.
The files are checked to have the same frequency.
Returns:
An array, containing poses with the shape
[paths][files][entities, timesteps, 4],
the common frequency of the files
"""
return get_all_data_from_paths(paths, "poses_4d", predicate)
def get_all_data_from_paths(
paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None
):
expected_settings = None
all_data = []
files_per_path = get_all_files_from_paths(paths)
# for each given path
for files_in_path in files_per_path:
data_from_files = []
# Open all files, gather the poses and check if all of them have the same world size and frequency
for i_path, file_path in enumerate(files_in_path):
with robofish.io.File(file_path, "r") as file:
file_settings = {
"world_size_cm_x": file.attrs["world_size_cm"][0],
"world_size_cm_y": file.attrs["world_size_cm"][1],
"frequency_hz": file.frequency,
}
if expected_settings is None:
expected_settings = file_settings
assert file_settings == expected_settings
properties = {
"poses_4d": robofish.io.Entity.poses,
"speeds_turns": robofish.io.Entity.speed_turn,
}
pred = None if predicate is None else predicate[i_path]
data = file.select_entity_property(
pred, entity_property=properties[request_type]
)
data_from_files.append(data)
all_data.append(data_from_files)
return all_data, expected_settings
def evaluate_speed(
paths: Iterable[Union[str, Path]],
labels: Iterable[str] = None,
......@@ -113,7 +42,7 @@ def evaluate_speed(
"""
if speeds_turns_from_paths is None:
speeds_turns_from_paths, file_settings = get_all_data_from_paths(
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns"
)
......@@ -172,7 +101,7 @@ def evaluate_turn(
"""
if speeds_turns_from_paths is None:
speeds_turns_from_paths, file_settings = get_all_data_from_paths(
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns"
)
......@@ -233,7 +162,7 @@ def evaluate_orientation(
(example: lambda e: e.category == "fish")
"""
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
world_bounds = [
-file_settings["world_size_cm_x"] / 2,
......@@ -317,7 +246,7 @@ def evaluate_relative_orientation(
"""
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
orientations = []
# Iterate all paths
......@@ -367,7 +296,7 @@ def evaluate_distance_to_wall(
"""
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
world_bounds = [
-file_settings["world_size_cm_x"] / 2,
......@@ -453,7 +382,7 @@ def evaluate_tank_position(
poses_step = 20
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
xy_positions = []
......@@ -512,7 +441,7 @@ def evaluate_social_vector(
"""
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
socialVec = []
......@@ -585,7 +514,7 @@ def evaluate_follow_iid(
(example: lambda e: e.category == "fish")
"""
if poses_from_paths is None:
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
follow, iid = [], []
......@@ -708,7 +637,7 @@ def evaluate_tracks(
random.seed(seed)
files_per_path = get_all_files_from_paths(paths)
files_per_path = utils.get_all_files_from_paths(paths)
max_files_per_path = max([len(files) for files in files_per_path])
rows, cols = len(files_per_path), min(4, max_files_per_path)
......@@ -780,13 +709,15 @@ def evaluate_individuals(
"""
if speeds_turns_from_paths is None and mode == "speed":
speeds_turns_from_paths, file_settings = get_all_data_from_paths(
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns"
)
if poses_from_paths is None and mode == "iid":
poses_from_paths, file_settings = get_all_data_from_paths(paths, "poses_4d")
poses_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "poses_4d"
)
files_from_paths = get_all_files_from_paths(paths)
files_from_paths = utils.get_all_files_from_paths(paths)
fig = plt.figure(figsize=(10, 4))
# small_iid_files = []
......@@ -798,7 +729,8 @@ def evaluate_individuals(
for f, file_path in enumerate(files_in_paths):
if mode == "speed":
metric = (
speeds_turns_from_paths[k][f][:, 0] * file_settings["frequency_hz"]
speeds_turns_from_paths[k][f][..., 0]
* file_settings["frequency_hz"]
)
elif mode == "iid":
poses = poses_from_paths[k][f]
......@@ -894,8 +826,8 @@ def evaluate_all(
fdict.pop("all")
print("Loading all poses and actions.")
poses_from_paths, file_settings = get_all_poses_from_paths(paths)
speeds_turns_from_paths, file_settings = get_all_data_from_paths(
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns"
)
......
......@@ -11,6 +11,8 @@
# -----------------------------------------------------------
import robofish.io
from robofish.io import utils
import argparse
import logging
......@@ -77,25 +79,30 @@ def validate(args=None):
logging.getLogger().setLevel(logging.ERROR)
sf_dict = robofish.io.read_multiple_files(args.path)
files_per_path = utils.get_all_files_from_paths(args.path)
files = [
f for f_in_path in files_per_path for f in f_in_path
] # Concatenate all files to one list
if len(sf_dict) == 0:
if len(files) == 0:
logging.getLogger().setLevel(logging.INFO)
logging.info("No files found in %s" % args.path)
return
validity_dict = {}
for fp in files:
with robofish.io.File(fp) as f:
validity_dict[str(fp)] = f.validate(strict_validate=False)
if args.output_format == "raw":
sf_dict = {
(str)(f): sf.validate(strict_validate=False)[0] for f, sf in sf_dict.items()
}
return sf_dict
return validity_dict
max_filename_width = max([len((str)(f)) for f in sf_dict.keys()])
max_filename_width = max([len(str(f)) for f in files])
error_code = 0
for file, sf in sf_dict.items():
filled_file = (str)(file).ljust(max_filename_width + 3)
validity, validity_message = sf.validate(strict_validate=False)
sf.close()
for fp, (validity, validity_message) in validity_dict.items():
filled_file = (str)(fp).ljust(max_filename_width + 3)
if not validity:
error_code = 1
print(f"{filled_file}:{validity}\t{validity_message}")
......
......@@ -65,6 +65,7 @@ class File(h5py.File):
mode: str = "r",
*, # PEP 3102
world_size_cm: List[int] = None,
validate: bool = False,
strict_validate: bool = False,
format_version: List[int] = default_format_version,
format_url: str = default_format_url,
......@@ -93,6 +94,8 @@ class File(h5py.File):
world_size_cm : [int, int] , optional
side lengths [x, y] of the world in cm.
rectangular world shape is assumed.
validate: bool, default=False
Should the track be validated? This is normally switched off for performance reasons.
strict_validate : bool, default=False
if the file should be strictly validated against the track
format specification, when loaded from a path.
......@@ -210,6 +213,7 @@ class File(h5py.File):
calendar_time_points=calendar_time_points,
default=True,
)
if validate:
self.validate(strict_validate)
def __enter__(self):
......@@ -994,7 +998,7 @@ class File(h5py.File):
update,
frames=n_frames,
init_func=init,
blit=False,
blit=True,
interval=self.frequency,
repeat=False,
)
......
......@@ -38,3 +38,74 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)):
else:
angle = limit_simple(angle)
return angle
def get_all_files_from_paths(paths: Iterable[Union[str, Path]]):
# Find all files with correct ending
files = []
for path in [Path(p) for p in paths]:
if path.is_dir():
files_path = []
for ext in ("hdf", "hdf5", "h5", "he5"):
files_path += list(path.rglob(f"*.{ext}"))
files.append(files_path)
else:
files.append([path])
return files
def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None):
"""Read all poses from given paths.
The function shall be used by the evaluation functions.
Args:
paths: An array of strings, with files or folders.
The files are checked to have the same frequency.
Returns:
An array, containing poses with the shape
[paths][files][entities, timesteps, 4],
the common frequency of the files
"""
return get_all_data_from_paths(paths, "poses_4d", predicate)
def get_all_data_from_paths(
paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None
):
expected_settings = None
all_data = []
files_per_path = get_all_files_from_paths(paths)
# for each given path
for files_in_path in files_per_path:
data_from_files = []
# Open all files, gather the poses and check if all of them have the same world size and frequency
for i_path, file_path in enumerate(files_in_path):
with robofish.io.File(file_path, "r") as file:
file_settings = {
"world_size_cm_x": file.attrs["world_size_cm"][0],
"world_size_cm_y": file.attrs["world_size_cm"][1],
"frequency_hz": file.frequency,
}
if expected_settings is None:
expected_settings = file_settings
assert file_settings == expected_settings
properties = {
"poses_4d": robofish.io.Entity.poses,
"speeds_turns": robofish.io.Entity.actions_speeds_turns,
}
pred = None if predicate is None else predicate[i_path]
data = file.select_entity_property(
pred, entity_property=properties[request_type]
)
data_from_files.append(data)
all_data.append(data_from_files)
return all_data, expected_settings
......@@ -19,11 +19,11 @@ def test_app_validate():
self.path = path
self.output_format = output_format
raw_output = app.validate(DummyArgs(resources_path, "raw"))
raw_output = app.validate(DummyArgs([resources_path], "raw"))
# The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found.
assert len(raw_output) == 4
app.validate(DummyArgs(resources_path, "human"))
app.validate(DummyArgs([resources_path], "human"))
def test_app_print():
......
import robofish.io
from robofish.io import utils
import pytest
from pathlib import Path
import numpy as np
resources_path = utils.full_path(__file__, "../../resources/")
h5py_file = utils.full_path(__file__, "../../resources/valid_1.hdf5")
def test_now_iso8061():
......@@ -14,43 +6,3 @@ def test_now_iso8061():
time = robofish.io.now_iso8061()
assert type(time) == str
assert len(time) == 32
def test_read_multiple_single():
path = h5py_file
# Variants path as posix path or as string
for sf in [
robofish.io.read_multiple_files(h5py_file),
robofish.io.read_multiple_files(str(h5py_file)),
]:
assert len(sf) == 1
for p, f in sf.items():
assert p == path
assert type(f) == robofish.io.File
def test_read_multiple_folder():
# Variants path as posix path or as string
for sf in [
robofish.io.read_multiple_files(resources_path),
robofish.io.read_multiple_files(str(resources_path)),
]:
# Should find the 4 available hdf5 files
assert len(sf) == 4
for p, f in sf.items():
print(p)
assert type(f) == robofish.io.File
# TODO read from folder of valid files
@pytest.mark.parametrize("param_path", [h5py_file, str(h5py_file)])
def test_read_poses_rad_from_multiple_folder(param_path):
poses = robofish.io.read_property_from_multiple_files(
[param_path, param_path], robofish.io.entity.Entity.poses_rad
)
# Should find the 3 presaved hdf5 files
assert len(poses) == 2
for p in poses:
print(p)
assert type(p) == np.ndarray
import pytest
import robofish.evaluate
from robofish.io import utils
from robofish.io import utils
import numpy as np
def test_get_all_poses_from_paths():
valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5")
poses, file_settings = robofish.evaluate.get_all_poses_from_paths([valid_file_path])
poses, file_settings = utils.get_all_poses_from_paths([valid_file_path])
# (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses)
assert np.array(poses).shape == (1, 1, 2, 100, 4)
type(file_settings) == dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment