diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 6d7b2689dddd331a25f8a11aebbd7b7ba8e037c9..97e3fdd8b040cf0d802eef6c05e714c8af3c8201 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -12,6 +12,7 @@ Functions available to be used in the commandline to evaluate robofish.io files. import robofish.evaluate import argparse from pathlib import Path +import matplotlib.pyplot as plt def function_dict(): @@ -20,12 +21,14 @@ def function_dict(): "speed": base.evaluate_speed, "turn": base.evaluate_turn, "orientation": base.evaluate_orientation, - "relative_orientation": base.evaluate_relativeOrientation, + "relative_orientation": base.evaluate_relative_orientation, "distance_to_wall": base.evaluate_distanceToWall, "tank_positions": base.evaluate_tankpositions, - "trajectories": base.evaluate_trajectories, + "tracks": base.evaluate_tracks, + "tracks_distance": base.evaluate_tracks_distance, "evaluate_positionVec": base.evaluate_positionVec, "follow_iid": base.evaluate_follow_iid, + "avg_speeds": base.evaluate_avg_speed, "all": base.evaluate_all, } @@ -89,13 +92,26 @@ def evaluate(args=None): if args is None: args = parser.parse_args() + if args.analysis_type == "all" and args.save_path is None: + raise Exception("When the analysis type is all, a path must be given.") + if args.analysis_type in fdict: save_path = None if args.save_path is None else Path(args.save_path) - params = (args.paths, args.names, save_path) + params = {"paths": args.paths, "labels": args.names} if args.analysis_type == "all": normal_functions = function_dict() normal_functions.pop("all") - params += (normal_functions,) - fdict[args.analysis_type](*params) + params["save_folder"] = save_path + params["fdict"] = normal_functions + save_paths = fdict[args.analysis_type](**params) + print("\n".join([str(p) for p in save_paths])) + else: + fig = fdict[args.analysis_type](**params) + + if save_path is None: + plt.show() + else: + fig.savefig(save_path) + plt.close(fig) else: print(f"Evaluation function not found {args.analysis_type}") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index fe48edbeeea0c0d22ba06cd3229f4487225e4562..a9c92e64f404c577eca0dd24cb536532575b8c81 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -19,6 +19,7 @@ import pandas as pd from typing import Iterable from scipy import stats from tqdm import tqdm +import random def get_all_poses_from_paths(paths: Iterable[str]): @@ -43,7 +44,7 @@ def get_all_poses_from_paths(paths: Iterable[str]): for files_dict in files_per_path ] - # close all files + # close all files and collect the frequencies frequencies = [] for files_dict in files_per_path: for path, f in files_dict.items(): @@ -59,7 +60,6 @@ def get_all_poses_from_paths(paths: Iterable[str]): def evaluate_speed( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the speed of the entities as histogram. @@ -68,7 +68,6 @@ def evaluate_speed( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -86,6 +85,7 @@ def evaluate_speed( for e_speeds_turns in file.entity_actions_speeds_turns: path_speeds.extend(e_speeds_turns[:, 0] * frequency) + file.close() path_speeds = np.array(path_speeds) left_quantiles.append(np.quantile(path_speeds, 0.001)) @@ -95,6 +95,7 @@ def evaluate_speed( if labels is None: labels = paths + fig = plt.figure() plt.hist( list(speeds), bins=20, @@ -107,18 +108,12 @@ def evaluate_speed( plt.ticklabel_format(useOffset=False, style="plain") plt.legend() plt.tight_layout() - - if save_path is None: - plt.show() - else: - plt.savefig(save_path) - plt.close() + return fig def evaluate_turn( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the turn angles of the entities as histogram. @@ -127,7 +122,6 @@ def evaluate_turn( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -145,6 +139,7 @@ def evaluate_turn( for e_speeds_turns in file.entity_actions_speeds_turns: path_turns.extend(np.rad2deg(e_speeds_turns[:, 1])) + file.close() path_turns = np.array(path_turns) left_quantiles.append(np.quantile(path_turns, 0.001)) @@ -154,6 +149,7 @@ def evaluate_turn( if labels is None: labels = paths + fig = plt.figure() plt.hist( turns, bins=41, @@ -168,17 +164,12 @@ def evaluate_turn( plt.legend() # plt.tight_layout() - if save_path is None: - plt.show() - else: - plt.savefig(save_path) - plt.close() + return fig def evaluate_orientation( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the orientations of the entities on a 2d grid. @@ -187,7 +178,6 @@ def evaluate_orientation( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -214,6 +204,7 @@ def evaluate_orientation( ret_2 = stats.binned_statistic_2d( poses[:, 0], poses[:, 1], poses[:, 3], "mean", bins=[xbins, ybins] ) + file.close() orientations.append((ret_1, ret_2)) fig, ax = plt.subplots(1, len(orientations), figsize=(8 * len(orientations), 8)) @@ -248,17 +239,12 @@ def evaluate_orientation( cbar = plt.colorbar(plot, ax=ax[i], pad=0.015, aspect=10) show_values(plot) - if save_path is None: - plt.show() - else: - plt.savefig(save_path) - plt.close() + return fig -def evaluate_relativeOrientation( +def evaluate_relative_orientation( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the relative orientations of the entities as a histogram. @@ -288,11 +274,13 @@ def evaluate_relativeOrientation( all_poses[j, :, 3] - poses[i, :, 3], ) path_orientations.extend(np.arctan2(ori_diff[1], ori_diff[0])) + file.close() orientations.append(path_orientations) if labels is None: labels = paths + fig = plt.figure() plt.hist(orientations, bins=40, label=labels, density=True, range=[0, np.pi]) plt.title("Relative orientation") plt.xlabel("orientation in radians") @@ -301,17 +289,12 @@ def evaluate_relativeOrientation( plt.legend() # plt.tight_layout() - if save_path is None: - plt.show() - else: - plt.savefig(save_path) - plt.close() + return fig def evaluate_distanceToWall( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the distances of the entities to the walls as a histogram. @@ -320,7 +303,6 @@ def evaluate_distanceToWall( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -359,6 +341,7 @@ def evaluate_distanceToWall( # use distance only from closest wall dist = np.stack(dist).min(axis=0) path_distances.extend(dist) + file.close() distances.append(path_distances) worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) @@ -366,6 +349,7 @@ def evaluate_distanceToWall( if labels is None: labels = paths + fig = plt.figure() plt.hist( distances, bins=20, @@ -381,17 +365,12 @@ def evaluate_distanceToWall( plt.legend() # plt.tight_layout() - if save_path is None: - plt.show() - else: - plt.savefig(save_path) - plt.close() + return fig def evaluate_tankpositions( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the positions of the entities as a heatmap. @@ -402,7 +381,6 @@ def evaluate_tankpositions( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -419,6 +397,7 @@ def evaluate_tankpositions( for e_poses in poses: path_x_pos.extend(e_poses[:, 0]) path_y_pos.extend(e_poses[:, 1]) + file.close() x_pos.append(np.array(path_x_pos)) y_pos.append(np.array(path_y_pos)) @@ -438,106 +417,12 @@ def evaluate_tankpositions( sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=25, shade=True, ax=ax[i]) - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close() - - -def evaluate_trajectories( - paths: Iterable[str], - labels: Iterable[str] = None, - save_path: str = None, - predicate=None, -): - """Evaluate the trajectories of the entities. - - The original implementation is by Moritz Maxeiner. - - Args: - paths: An array of strings, with files of folders. - labels: Labels for the paths. If no labels are given, the paths will - be used - save_path: A path to a save location. - predicate: a lambda function, selecting entities - (example: lambda e: e.category == "fish") - """ - files_per_path = [robofish.io.read_multiple_files(p) for p in paths] - pos = [] - world_bounds = [] - for k, files in enumerate(files_per_path): - path_poses = [] - for p, file in files.items(): - poses = file.select_entity_poses( - None if predicate is None else predicate[k] - ) - world_bounds.append(file.attrs["world_size_cm"]) - path_poses.append(poses[:, :, :2]) - poses = np.concatenate(path_poses, axis=1) - - path_pos = { - fish: pd.DataFrame({"x": poses[fish, :, 0], "y": poses[fish, :, 1]}) - for fish in range(len(poses)) - } - combined = pd.concat( - [ - path_pos[fish].assign(Agent=f"{file.entity_names[fish]}") - for fish in path_pos.keys() - ] - ) - - pos.append((path_pos, combined)) - - fig, ax = plt.subplots(1, len(pos), figsize=(len(pos) * 8, 8)) - if len(pos) == 1: - ax = [ax] - if labels is None: - labels = paths - - for i in range(len(pos)): - sns.set_style("white", {"axes.linewidth": 2, "axes.edgecolor": "black"}) - sns.scatterplot( - x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i] - ) - ax[i].set_title(f"Trajectories\n{labels[i]}") - ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) - ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) - ax[i].set_xlabel("x [cm]") - ax[i].set_ylabel("y [cm]") - ax[i].xaxis.set_ticks_position("top") - ax[i].xaxis.set_label_position("top") - ax[i].yaxis.set_ticks_position("left") - ax[i].yaxis.set_label_position("left") - ax[i].scatter( - [frame["x"][0] for frame in pos[i][0].values()], - [frame["y"][0] for frame in pos[i][0].values()], - marker="h", - c="black", - s=32, - label="Start", - ) - ax[i].scatter( - [frame["x"][len(frame["x"]) - 1] for frame in pos[i][0].values()], - [frame["y"][len(frame["y"]) - 1] for frame in pos[i][0].values()], - marker="x", - c="black", - s=32, - label="End", - ) - ax[i].legend() - - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close() + return fig def evaluate_positionVec( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap. @@ -546,7 +431,6 @@ def evaluate_positionVec( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -574,6 +458,7 @@ def evaluate_positionVec( posVec_input, np.arctan2(poses[i, :, 3], poses[i, :, 2]) ) ) + file.close() posVec.append(np.concatenate(path_posVec, axis=0)) grids = [] @@ -596,17 +481,12 @@ def evaluate_positionVec( for i in range(len(grids)): SeabornFig2Grid(grids[i], fig, gs[i]) - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close() + return fig def evaluate_follow_iid( paths: Iterable[str], labels: Iterable[str] = None, - save_path: str = None, predicate=None, ): """Evaluate the follow metric in respect to the inter individual distance (iid). @@ -617,7 +497,6 @@ def evaluate_follow_iid( paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used - save_path: A path to a save location. predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ @@ -643,6 +522,7 @@ def evaluate_follow_iid( path_follow.append( calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2]) ) + file.close() follow.append(path_follow) iid.append(path_iid) @@ -681,11 +561,126 @@ def evaluate_follow_iid( for i in range(len(grids)): SeabornFig2Grid(grids[i], fig, gs[i]) - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close() + return fig + + +def evaluate_tracks_distance( + paths: Iterable[str], + labels: Iterable[str] = None, + predicate=None, +): + """Evaluate the distances of two or more fish on the track. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ + return evaluate_tracks(paths, labels, predicate, lw_distances=True) + + +def evaluate_tracks( + paths: Iterable[str], + labels: Iterable[str] = None, + predicate=None, + lw_distances=False, +): + """Evaluate the track. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ + + random.seed() + + files_per_path = [robofish.io.read_multiple_files(p) for p in paths] + max_files_per_path = max([len(files) for files in files_per_path]) + rows, cols = len(files_per_path), min(4, max_files_per_path) + + multirow = False + if rows == 1 and cols == 4: + rows = min(3, int(np.ceil(max_files_per_path / 4))) + multirow = True + + fig, ax = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False) + + for k, files in enumerate(files_per_path): + file_paths = list(files.keys()) + random.shuffle(file_paths) + for i, path in enumerate(file_paths): + file = files[path] + if multirow: + if i >= cols * rows: + break + file.plot(ax[i // cols][i % cols], lw_distances=lw_distances) + else: + if i >= cols: + break + file.plot(ax[k][i], lw_distances=lw_distances) + + file.close() + + plt.tight_layout() + + return fig + + +def evaluate_avg_speed( + paths: Iterable[str], + labels: Iterable[str] = None, + predicate=None, + lw_distances=False, +): + """Evaluate the track. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ + + files_per_path = [robofish.io.read_multiple_files(p) for p in paths] + + fig = plt.figure(figsize=(10, 4)) + + if labels is None: + labels = paths + + offset = 0 + for k, files in enumerate(files_per_path): + all_avg_speeds = [] + all_std_speeds = [] + for path, file in files.items(): + speeds = file.entity_actions_speeds_turns[..., 0] * file.frequency + all_avg_speeds.append(np.mean(speeds, axis=1)) + all_std_speeds.append(np.std(speeds, axis=1)) + file.close() + all_avg_speeds = np.concatenate(all_avg_speeds, axis=0) + all_std_speeds = np.concatenate(all_std_speeds, axis=0) + individuals = all_avg_speeds.shape[0] + plt.errorbar( + np.arange(offset, individuals + offset), + all_avg_speeds, + all_std_speeds, + label=labels[k], + fmt="o", + ) + offset += individuals + plt.title("Average speed per individual") + plt.legend(loc="upper right") + plt.xlabel("Individual ID") + plt.ylabel("Average Speed (cm / s) +/- Std Dev") + plt.tight_layout() + + return fig def evaluate_all( @@ -724,7 +719,9 @@ def evaluate_all( t.set_description(f_name) t.refresh() # to show immediately the update save_path = save_folder / (f_name + ".png") - f_callable(paths, labels, save_path, predicate) + fig = f_callable(paths, labels, predicate) + fig.savefig(save_path) + plt.close(fig) save_paths.append(save_path) return save_paths diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 82cfefcad880788bf6d79895acf02ea7c344aaf2..1dfb4d1d512c1486380d028c853ece7f66e67f43 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -31,11 +31,13 @@ import uuid import deprecation import types import warnings +from textwrap import wrap import matplotlib as mpl import matplotlib.pyplot as plt from matplotlib import animation from matplotlib import patches +from matplotlib import cm from subprocess import run @@ -133,6 +135,8 @@ class File(h5py.File): a temporary copy of the file will be opened instead of the file itself. """ + self.path = path + if open_copy: assert ( path is not None @@ -642,6 +646,79 @@ class File(h5py.File): def __str__(self): return self.to_string() + def plot(self, ax=None, lw_distances=False, figsize=None, step_size=4): + poses = self.entity_poses[:, :, :2] + + if lw_distances and poses.shape[0] < 2: + lw_distances = False + + if lw_distances: + poses_diff = np.diff(poses, axis=0) # Axis 0 is fish + distances = np.linalg.norm(poses_diff, axis=2) + + min_distances = np.min(distances, axis=0) + + # Magic numbers found by trial and error. Everything above 15cm will be represented as line width 1 + max_distance = 10 + max_lw = 4 + line_width = ( + np.clip(max_distance - min_distances, 1, max_distance) + * max_lw + / max_distance + ) + else: + step_size = poses.shape[1] + line_width = 1 + + cmap = cm.get_cmap("Set1") + + x_world, y_world = self.world_size + if ax is None: + fig, ax = plt.subplots(1, 1, figsize=figsize) + + if self.path is not None: + ax.set_title("\n".join(wrap(Path(self.path).name, width=35))) + + ax.set_xlim(-x_world / 2, x_world / 2) + ax.set_ylim(-y_world / 2, y_world / 2) + for fish_id in range(poses.shape[0]): + c = cmap(fish_id) + for t in range(0, poses.shape[1] - 1, step_size): + if lw_distances: + lw = np.mean(line_width[t : t + step_size + 1]) + else: + lw = 1 + ax.plot( + poses[fish_id, t : t + step_size + 1, 0], + poses[fish_id, t : t + step_size + 1, 1], + c=c, + lw=lw, + ) + # Plotting outside of the figure to have the label + ax.plot([55, 60], [55, 60], lw=5, c=c, label=fish_id) + ax.scatter( + [poses[:, 0, 0]], + [poses[:, 0, 1]], + marker="h", + c="black", + s=32, + label="Start", + zorder=5, + ) + ax.scatter( + [poses[:, -1, 0]], + [poses[:, -1, 1]], + marker="x", + c="black", + s=32, + label="End", + zorder=5, + ) + ax.legend(loc="lower right") + ax.set_xlabel("x [cm]") + ax.set_ylabel("y [cm]") + return ax + def render(self, video_path=None, **kwargs): """Render a video of the file. diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 601657d173f643d62502e5732c4ae5acb0a28166..c4d800fc9f48465f65c3a86ee9558b9a0b30fe84 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -232,6 +232,12 @@ def test_loading_saving(tmp_path): sf.close() +def test_file_plot(): + with robofish.io.File(valid_file_path) as f: + f.plot() + f.plot(lw_distances=True) + + if __name__ == "__main__": # Find all functions in this module and execute them all_functions = inspect.getmembers(sys.modules[__name__], inspect.isfunction)