Skip to content
Snippets Groups Projects
Commit b3d7246d authored by Andi Gerken's avatar Andi Gerken
Browse files

- Added first evaluation script

- Fixed bug when loading multiple files
parent d270bc59
Branches develop
No related tags found
No related merge requests found
......@@ -7,6 +7,8 @@ entry_points = {
"console_scripts": [
"robofish-io-validate=robofish.io.app:validate",
"robofish-io-print=robofish.io.app:print",
# TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritzs
"robofish-io-evaluate=robofish.evaluate.app:evaluate",
]
}
setup(
......
import sys
import logging
import robofish.io
from robofish.evaluate.evaluate import *
import robofish.evaluate.app
# TODO: REMOVE
# logging.getLogger().setLevel(logging.INFO)
assert (3, 7) <= sys.version_info < (4, 0), "Unsupported Python version"
# -*- coding: utf-8 -*-
# -----------------------------------------------------------
# Functions available to be used in the commandline to evaluate robofish.io files
#
# Dec 2020 Andreas Gerken, Berlin, Germany
# Released under GNU 3.0 License
# email andi.gerken@gmail.com
# -----------------------------------------------------------
import robofish.evaluate
import argparse
def evaluate(args=None):
"""This function can be used to print hdf5 files from the command line
Returns:
A human readable print of a given hdf5 file.
"""
parser = argparse.ArgumentParser(description="TODO")
parser.add_argument("analysis_type", type=str, choices=["speed"])
parser.add_argument(
"paths",
type=str,
nargs="+",
help="The paths to io/hdf5 files. Multiple paths can be given which will be shown in different colors",
)
if args is None:
args = parser.parse_args()
if args.analysis_type == "speed":
robofish.evaluate.evaluate.evaluate_speed(args.paths)
import robofish.evaluate
import robofish.io
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
def evaluate_speed(paths):
files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
speeds = []
for files in files_per_path:
path_speeds = []
for p, file in files.items():
poses = file.get_poses_array()
for e_poses in poses:
e_speeds = np.linalg.norm(np.diff(e_poses[:, :2], axis=0), axis=1)
path_speeds.extend(e_speeds)
speeds.append(path_speeds)
plt.hist(speeds, bins=20, label=paths, density=True, range=[0, 1])
# sns.displot(speeds, label=paths, multiple="layer", kind="kde")
plt.title("Agent speeds")
plt.xlabel("Speed (cm/timestep)")
plt.ylabel("Frequency (logscale)")
plt.ticklabel_format(useOffset=False)
# plt.xscale("log", nonpositive="clip")
plt.legend()
plt.tight_layout()
plt.show()
# SPDX-License-Identifier: LGPL-3.0-or-later
import sys
import logging
# TODO: REMOVE
logging.getLogger().setLevel(logging.INFO)
from robofish.io.file import File
from robofish.io.validation import *
from robofish.io.io import *
......
......@@ -223,8 +223,11 @@ class File(h5py.File):
names = self["entities"].keys()
n = len(names)
timesteps = max(
[self["entities"][e_name]["poses"].shape[0] for e_name in names]
timesteps = (
0
if n == 0
else max([self["entities"][e_name]["poses"].shape[0] for e_name in names])
)
# Initialize poses output array
......
......@@ -4,6 +4,17 @@ from typing import Union, Iterable
from pathlib import Path
import logging
import numpy as np
# Optional pandas series support
list_types = (list, np.ndarray)
try:
import pandas
list_types += (pandas.core.series.Series,)
except ImportError:
pass
def now_iso8061():
return datetime.datetime.now(datetime.timezone.utc).isoformat(
......@@ -29,16 +40,16 @@ def read_multiple_files(
logging.info(f"Reading files from path {paths}")
try:
iter(paths)
except TypeError:
if not isinstance(paths, list_types):
paths = [paths]
paths = [Path(p) for p in paths]
sf_dict = {}
for path in paths:
if path.is_dir():
logging.info("found dir %s" % path)
# Find all hdf5 files in folder
files = []
for ext in ("hdf", "hdf5", "h5", "he5"):
......@@ -51,7 +62,8 @@ def read_multiple_files(
sf_dict.update(
{file: robofish.io.File(path=file, strict_validate=strict_validate)}
)
elif path is not None:
elif path is not None and path.exists():
logging.info("found file %s" % path)
sf_dict[path] = robofish.io.File(path=path, strict_validate=strict_validate)
return sf_dict
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment