diff --git a/src/conversion_scripts/convert_socoro.py b/src/conversion_scripts/convert_socoro.py new file mode 100755 index 0000000000000000000000000000000000000000..ff4056220e57ea542597ec6a70be9e667279560c --- /dev/null +++ b/src/conversion_scripts/convert_socoro.py @@ -0,0 +1,71 @@ +from pathlib import Path +import argparse + +import numpy as np +import robofish.io +from tqdm import tqdm + +from robofish.socoro.preprocessing.stats import load_trial + + +def convert(input_path: Path, output_path: Path): + """ + Convert from csv format to RoboFish track format. + + robofish.socoro is used for the load_trial function. + This is needed because the csv files do not have headers. + Unused columns: "frame_number", "datetime_local". + These columns could be used if needed, e.g. convert datetime_local to + calendar_time_points or include frame_number if files will be split into segments. + + TODO: filter warning about unbounded radians + """ + trial_df = load_trial(input_path) + robot_poses = np.stack( + [ + np.clip(trial_df.robot_position_x_cm, 0, 100) - 50, + np.clip(trial_df.robot_position_y_cm, 0, 100) - 50, + -trial_df.robot_orientation_rad, + ] + ).T + guppy_poses = np.stack( + [ + np.clip(trial_df.fish_position_x_cm, 0, 100) - 50, + np.clip(trial_df.fish_position_y_cm, 0, 100) - 50, + -trial_df.fish_orientation_rad, + ] + ).T + + f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) + + f.create_entity(category="robot", poses=robot_poses, name="robot") + f.create_entity(category="organism", poses=guppy_poses, name="guppy") + + for robot_dataset in [ + trial_df.robot_mode.astype(bytes), + trial_df.avoidance_score, + trial_df.follow_value, + trial_df.carefulness_variable, + ]: + f["entities"]["robot"].create_dataset(robot_dataset.name, data=robot_dataset) + f["samplings"]["25 hz"].create_dataset( + "monotonic_time_points_us", + data=(trial_df.monotonic_time_ms - trial_df.monotonic_time_ms[0]) * 1000, + ) + f.save_as(output_path, no_warning=True) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Convert csv file from the socoro experiments to RoboFish track format." + ) + parser.add_argument( + "input", type=str, help="Single folder of files to be converted" + ) + parser.add_argument("output", type=str, help="Output path") + args = parser.parse_args() + + output_path = Path(args.output) + output_path.mkdir(exist_ok=True) + for path in tqdm(list(Path(args.input).glob("*.csv"))): + convert(input_path=path, output_path=output_path / f"{path.stem}.hdf5") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index fd7e9afa73fd3d02e28617527fe97899ef9c797c..0316122dc4d9150b37b9447f2d6da616b85a59db 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -494,7 +494,7 @@ def evaluate_quiver( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, predicate ) if speeds_turns_from_paths is None: speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index b4ef9d0874bd88e1bc374f86d63fc65a3c8623b9..b4b1818ff3d992fe215aef037de023465b1d444a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -934,15 +934,6 @@ class File(h5py.File): # Plotting outside of the figure to have the label ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) - # ax.scatter( - # [poses[:, skip_timesteps, 0]], - # [poses[:, skip_timesteps, 1]], - # marker="h", - # c="black", - # s=ms, - # label="Start", - # zorder=5, - # ) ax.scatter( [poses[:, -1, 0]], [poses[:, -1, 1]], @@ -1108,9 +1099,12 @@ class File(h5py.File): ) xv, yv = np.meshgrid(x, y) - grid_points = plt.scatter(xv, yv, c="gray", s=1.5) + points = [ + plt.scatter([], [], marker="x", color="k"), + plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], + plt.scatter(xv, yv, c="gray", s=1.5), + ] - # border = plt.plot(border_vertices[0], border_vertices[1], "k") border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) def title(file_frame: int) -> str: @@ -1201,7 +1195,7 @@ class File(h5py.File): [options["view_size"], min_view + options["margin"]] ) - if not np.isnan(min_view).any() and not new_view_size is np.nan: + if not np.isnan(min_view).any() and new_view_size is not np.nan: self.middle_of_swarm = options[ "slow_view" ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(