Skip to content
Snippets Groups Projects
Commit 285a015e authored by Andi Gerken's avatar Andi Gerken
Browse files

Merge branch 'master' into develop

parents d2aa9c98 c95c84ed
No related branches found
No related tags found
1 merge request!37Added calculation of individual ids
from pathlib import Path
import argparse
import numpy as np
import robofish.io
from tqdm import tqdm
from robofish.socoro.preprocessing.stats import load_trial
def convert(input_path: Path, output_path: Path):
"""
Convert from csv format to RoboFish track format.
robofish.socoro is used for the load_trial function.
This is needed because the csv files do not have headers.
Unused columns: "frame_number", "datetime_local".
These columns could be used if needed, e.g. convert datetime_local to
calendar_time_points or include frame_number if files will be split into segments.
TODO: filter warning about unbounded radians
"""
trial_df = load_trial(input_path)
robot_poses = np.stack(
[
np.clip(trial_df.robot_position_x_cm, 0, 100) - 50,
np.clip(trial_df.robot_position_y_cm, 0, 100) - 50,
-trial_df.robot_orientation_rad,
]
).T
guppy_poses = np.stack(
[
np.clip(trial_df.fish_position_x_cm, 0, 100) - 50,
np.clip(trial_df.fish_position_y_cm, 0, 100) - 50,
-trial_df.fish_orientation_rad,
]
).T
f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)
f.create_entity(category="robot", poses=robot_poses, name="robot")
f.create_entity(category="organism", poses=guppy_poses, name="guppy")
for robot_dataset in [
trial_df.robot_mode.astype(bytes),
trial_df.avoidance_score,
trial_df.follow_value,
trial_df.carefulness_variable,
]:
f["entities"]["robot"].create_dataset(robot_dataset.name, data=robot_dataset)
f["samplings"]["25 hz"].create_dataset(
"monotonic_time_points_us",
data=(trial_df.monotonic_time_ms - trial_df.monotonic_time_ms[0]) * 1000,
)
f.save_as(output_path, no_warning=True)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Convert csv file from the socoro experiments to RoboFish track format."
)
parser.add_argument(
"input", type=str, help="Single folder of files to be converted"
)
parser.add_argument("output", type=str, help="Output path")
args = parser.parse_args()
output_path = Path(args.output)
output_path.mkdir(exist_ok=True)
for path in tqdm(list(Path(args.input).glob("*.csv"))):
convert(input_path=path, output_path=output_path / f"{path.stem}.hdf5")
...@@ -494,7 +494,7 @@ def evaluate_quiver( ...@@ -494,7 +494,7 @@ def evaluate_quiver(
if poses_from_paths is None: if poses_from_paths is None:
poses_from_paths, file_settings = utils.get_all_poses_from_paths( poses_from_paths, file_settings = utils.get_all_poses_from_paths(
paths, predicate paths, predicate, predicate
) )
if speeds_turns_from_paths is None: if speeds_turns_from_paths is None:
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
......
...@@ -934,15 +934,6 @@ class File(h5py.File): ...@@ -934,15 +934,6 @@ class File(h5py.File):
# Plotting outside of the figure to have the label # Plotting outside of the figure to have the label
ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)
# ax.scatter(
# [poses[:, skip_timesteps, 0]],
# [poses[:, skip_timesteps, 1]],
# marker="h",
# c="black",
# s=ms,
# label="Start",
# zorder=5,
# )
ax.scatter( ax.scatter(
[poses[:, -1, 0]], [poses[:, -1, 0]],
[poses[:, -1, 1]], [poses[:, -1, 1]],
...@@ -1108,9 +1099,12 @@ class File(h5py.File): ...@@ -1108,9 +1099,12 @@ class File(h5py.File):
) )
xv, yv = np.meshgrid(x, y) xv, yv = np.meshgrid(x, y)
grid_points = plt.scatter(xv, yv, c="gray", s=1.5) points = [
plt.scatter([], [], marker="x", color="k"),
plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
plt.scatter(xv, yv, c="gray", s=1.5),
]
# border = plt.plot(border_vertices[0], border_vertices[1], "k")
border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)
def title(file_frame: int) -> str: def title(file_frame: int) -> str:
...@@ -1201,7 +1195,7 @@ class File(h5py.File): ...@@ -1201,7 +1195,7 @@ class File(h5py.File):
[options["view_size"], min_view + options["margin"]] [options["view_size"], min_view + options["margin"]]
) )
if not np.isnan(min_view).any() and not new_view_size is np.nan: if not np.isnan(min_view).any() and new_view_size is not np.nan:
self.middle_of_swarm = options[ self.middle_of_swarm = options[
"slow_view" "slow_view"
] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment