Skip to content
Snippets Groups Projects
Commit 23654f46 authored by Andi Gerken's avatar Andi Gerken
Browse files

Fixed bug in evaluation scripts and added more tests.

parent c7b2181a
Branches
Tags 0.2.3
No related merge requests found
Pipeline #41048 passed
......@@ -6,6 +6,7 @@ dist
.vscode
.venv
.coverage
.testmondata
report.xml
htmlcov
html
......@@ -16,5 +17,7 @@ env
*.hdf5
*.mp4
!tests/resources/*.hdf5
feature_requests.md
output_graph.png
......@@ -73,6 +73,8 @@ def evaluate_speed(
"""
files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
speeds = []
left_quantiles = []
right_quantiles = []
frequency = None
for k, files in enumerate(files_per_path):
......@@ -82,19 +84,22 @@ def evaluate_speed(
frequency = file.frequency
for e_speeds_turns in file.entity_actions_speeds_turns:
path_speeds = np.concatenate(
[path_speeds, e_speeds_turns[:, 0] * frequency]
)
path_speeds.extend(e_speeds_turns[:, 0] * frequency)
path_speeds = np.array(path_speeds)
left_quantiles.append(np.quantile(path_speeds, 0.001))
right_quantiles.append(np.quantile(path_speeds, 0.999))
speeds.append(path_speeds)
if labels is None:
labels = paths
speeds = np.array(speeds)
left_quantile = np.min(np.quantile(speeds, 0.001, axis=1))
right_quantile = np.max(np.quantile(speeds, 0.999, axis=1))
plt.hist(list(speeds), bins=20, label=labels, range=[left_quantile, right_quantile])
plt.hist(
list(speeds),
bins=20,
label=labels,
range=[min(left_quantiles), max(right_quantiles)],
)
plt.title("Agent speeds")
plt.xlabel("Speed [cm/s]")
plt.ylabel("Frequency")
......@@ -132,25 +137,28 @@ def evaluate_turn(
for k, files in enumerate(files_per_path):
path_turns = []
left_quantiles, right_quantiles = [], []
for p, file in files.items():
assert frequency is None or frequency == file.frequency
frequency = file.frequency
for e_speeds_turns in file.entity_actions_speeds_turns:
path_turns.extend(np.rad2deg(e_speeds_turns[:, 1]))
path_turns = np.array(path_turns)
left_quantiles.append(np.quantile(path_turns, 0.001))
right_quantiles.append(np.quantile(path_turns, 0.999))
turns.append(path_turns)
if labels is None:
labels = paths
left_quantile = np.min(np.quantile(np.array(turns), 0.005, axis=1))
right_quantile = np.max(np.quantile(np.array(turns), 0.995, axis=1))
plt.hist(
turns,
bins=41,
label=labels,
density=True,
range=[left_quantile, right_quantile],
range=[min(left_quantiles), max(right_quantiles)],
)
plt.title("Agent turns")
plt.xlabel("Change in orientation [Degree / timestep at %dhz]" % frequency)
......@@ -697,6 +705,12 @@ def evaluate_all(
predicate: a lambda function, selecting entities
(example: lambda e: e.category == "fish")
"""
assert (
save_folder is not None and save_folder != ""
), "Please provide a save_folder using --save_path"
save_folder.mkdir(exist_ok=True)
t = tqdm(fdict.items(), desc="Evaluation", leave=True)
for f_name, f_callable in t:
t.set_description(f_name)
......
File moved
File added
......@@ -9,7 +9,8 @@ np.seterr(all="raise")
logging.getLogger().setLevel(logging.INFO)
h5py_file = utils.full_path(__file__, "../../resources/valid.hdf5")
h5py_file_1 = utils.full_path(__file__, "../../resources/valid_1.hdf5")
h5py_file_2 = utils.full_path(__file__, "../../resources/valid_2.hdf5")
def test_app_validate(tmp_path):
......@@ -24,10 +25,10 @@ def test_app_validate(tmp_path):
for mode in app.function_dict().keys():
if mode == "all":
app.evaluate(DummyArgs(mode, [h5py_file], tmp_path))
app.evaluate(DummyArgs(mode, [h5py_file, h5py_file], tmp_path))
app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path))
app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path))
else:
app.evaluate(DummyArgs(mode, [h5py_file], tmp_path / "image.png"))
app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png"))
app.evaluate(
DummyArgs(mode, [h5py_file, h5py_file], tmp_path / "image.png")
DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png")
)
......@@ -5,7 +5,7 @@ import numpy as np
def test_get_all_poses_from_paths():
valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5")
valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5")
poses, frequency = robofish.evaluate.get_all_poses_from_paths([valid_file_path])
# (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses)
......
......@@ -7,34 +7,32 @@ from pathlib import Path
logging.getLogger().setLevel(logging.INFO)
resources_path = utils.full_path(__file__, "../../resources")
h5py_file = utils.full_path(__file__, "../../resources/valid_1.hdf5")
def test_app_validate():
""" This tests the function of the robofish-io-validate command """
"""This tests the function of the robofish-io-validate command"""
class DummyArgs:
def __init__(self, path, output_format):
self.path = path
self.output_format = output_format
raw_output = app.validate(
DummyArgs(utils.full_path(__file__, "../../resources"), "raw")
)
raw_output = app.validate(DummyArgs(resources_path, "raw"))
# The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found.
assert len(raw_output) == 2
app.validate(DummyArgs(utils.full_path(__file__, "../../resources"), "human"))
assert len(raw_output) == 3
app.validate(DummyArgs(resources_path, "human"))
def test_app_print():
""" This tests the function of the robofish-io-validate command """
"""This tests the function of the robofish-io-validate command"""
class DummyArgs:
def __init__(self, path, output_format):
self.path = path
self.output_format = output_format
app.print_file(
DummyArgs(utils.full_path(__file__, "../../resources/valid.hdf5"), "full")
)
app.print_file(
DummyArgs(utils.full_path(__file__, "../../resources/valid.hdf5"), "shape")
)
app.print_file(DummyArgs(h5py_file, "full"))
app.print_file(DummyArgs(h5py_file, "shape"))
......@@ -10,7 +10,7 @@ import logging
LOGGER = logging.getLogger(__name__)
valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5")
valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5")
def test_constructor():
......
......@@ -5,6 +5,10 @@ from pathlib import Path
import numpy as np
resources_path = utils.full_path(__file__, "../../resources/")
h5py_file = utils.full_path(__file__, "../../resources/valid_1.hdf5")
def test_now_iso8061():
# Example time: 2021-01-05T14:33:40.401000+00:00
time = robofish.io.now_iso8061()
......@@ -13,12 +17,12 @@ def test_now_iso8061():
def test_read_multiple_single():
path = utils.full_path(__file__, "../../resources/valid.hdf5")
path = h5py_file
# Variants path as posix path or as string
for sf in [
robofish.io.read_multiple_files(path),
robofish.io.read_multiple_files(str(path)),
robofish.io.read_multiple_files(h5py_file),
robofish.io.read_multiple_files(str(h5py_file)),
]:
assert len(sf) == 1
for p, f in sf.items():
......@@ -27,27 +31,23 @@ def test_read_multiple_single():
def test_read_multiple_folder():
path = utils.full_path(__file__, "../../resources/")
# Variants path as posix path or as string
for sf in [
robofish.io.read_multiple_files(path),
robofish.io.read_multiple_files(str(path)),
robofish.io.read_multiple_files(resources_path),
robofish.io.read_multiple_files(str(resources_path)),
]:
# Should find the 3 presaved hdf5 files
assert len(sf) == 2
assert len(sf) == 3
for p, f in sf.items():
print(p)
assert type(f) == robofish.io.File
path = utils.full_path(__file__, "../../resources/valid.hdf5")
# TODO read from folder of valid files
@pytest.mark.parametrize("_path", [path, str(path)])
def test_read_poses_rad_from_multiple_folder(_path):
@pytest.mark.parametrize("param_path", [h5py_file, str(h5py_file)])
def test_read_poses_rad_from_multiple_folder(param_path):
poses = robofish.io.read_property_from_multiple_files(
[_path, _path], robofish.io.entity.Entity.poses_rad
[param_path, param_path], robofish.io.entity.Entity.poses_rad
)
# Should find the 3 presaved hdf5 files
assert len(poses) == 2
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment