diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index cd52952f1431b8744e394a5feb1466057c2881e1..183b5ddacd4cae3fd3a321e8e8bf597f38eac920 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -29,6 +29,7 @@ def function_dict(): "evaluate_positionVec": base.evaluate_positionVec, "follow_iid": base.evaluate_follow_iid, "individual_speeds": base.evaluate_individual_speeds, + "individual_iid": base.evaluate_individual_iid, "all": base.evaluate_all, } @@ -55,6 +56,9 @@ def evaluate(args=None): formatter_class=argparse.RawTextHelpFormatter, ) + for name, func in fdict.items(): + assert func.__doc__ is not None, f"Function '{name}' does not have a docstring." + parser.add_argument( "analysis_type", type=str, diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 60c6e7fc1fc22180a1279921a7a3b8ca33cb14bf..e10bbb58df736ab73ef93bdb7b5575dbf4f13f0b 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -613,15 +613,28 @@ def evaluate_tracks( return fig -def evaluate_individual_speeds( +def evaluate_individual_speeds(**kwargs): + """Evaluate the average speeds per individual""" + return evaluate_individuals(mode="speed", **kwargs) + + +def evaluate_individual_iid(**kwargs): + """Evaluate the average iid per file""" + return evaluate_individuals(mode="iid", **kwargs) + + +def evaluate_individuals( + mode: str, paths: Iterable[str], labels: Iterable[str] = None, predicate=None, lw_distances=False, + threshold=7, ): """Evaluate the average speeds per individual. Args: + mode: A choice between ['speed', 'iid'] paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will be used @@ -632,33 +645,68 @@ def evaluate_individual_speeds( files_per_path = [robofish.io.read_multiple_files(p) for p in paths] fig = plt.figure(figsize=(10, 4)) - + small_iid_files = [] offset = 0 for k, files in enumerate(files_per_path): - all_avg_speeds = [] - all_std_speeds = [] + all_avg = [] + all_std = [] + for path, file in files.items(): - speeds = file.entity_actions_speeds_turns[..., 0] * file.frequency - all_avg_speeds.append(np.mean(speeds, axis=1)) - all_std_speeds.append(np.std(speeds, axis=1)) + if mode == "speed": + metric = file.entity_actions_speeds_turns[..., 0] * file.frequency + elif mode == "iid": + poses = file.entity_poses_rad + metric = calculate_iid(poses[0], poses[1])[None] + mean = np.mean(metric, axis=1)[0] + if mean < 7: + small_iid_files.append(str(file.path)) + all_avg.append(np.mean(metric, axis=1)) + all_std.append(np.std(metric, axis=1)) file.close() - all_avg_speeds = np.concatenate(all_avg_speeds, axis=0) - all_std_speeds = np.concatenate(all_std_speeds, axis=0) - individuals = all_avg_speeds.shape[0] + all_avg = np.concatenate(all_avg, axis=0) + all_std = np.concatenate(all_std, axis=0) + individuals = all_avg.shape[0] + plt.errorbar( np.arange(offset, individuals + offset), - all_avg_speeds, - all_std_speeds, + all_avg, + all_std, label=labels[k], fmt="o", ) offset += individuals - plt.title("Average speed per individual") + + if mode == "iid": + plt.plot([0, offset], [threshold, threshold], c="red", alpha=0.5) + + text = { + "speed": { + "title": "Average speed per individual", + "ylabel": "Average Speed (cm / s) +/- Std Dev", + "xlabel": "Individual ID", + }, + "iid": { + "title": "Average iid per individual", + "ylabel": "Average iid (cm) +/- Std Dev", + "xlabel": "File ID", + }, + } + this_text = text[mode] + plt.title(this_text["title"]) plt.legend(loc="upper right") - plt.xlabel("Individual ID") - plt.ylabel("Average Speed (cm / s) +/- Std Dev") + plt.xlabel(this_text["xlabel"]) + plt.ylabel(this_text["ylabel"]) plt.tight_layout() + np.random.shuffle(small_iid_files) + + train_part = int(len(small_iid_files) * 0.8) + + print("TRAINING DATA:") + print(" ".join(small_iid_files[:train_part])) + + print("TEST DATA:") + print(" ".join(small_iid_files[train_part:])) return fig diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 0d8396de57681daa208bffab879faf12f0f17fad..7580ac24b97f3b461bd66fe7816556ecc86b4a8f 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -675,7 +675,7 @@ class File(h5py.File): def __str__(self): return self.to_string() - def plot(self, ax=None, lw_distances=False, figsize=None, step_size=4): + def plot(self, ax=None, lw_distances=False, figsize=None, step_size=4, c=None): poses = self.entity_poses[:, :, :2] if lw_distances and poses.shape[0] < 2: @@ -702,6 +702,9 @@ class File(h5py.File): cmap = cm.get_cmap("Set1") x_world, y_world = self.world_size + if figsize is None: + figsize = (8, 8) + if ax is None: fig, ax = plt.subplots(1, 1, figsize=figsize) @@ -711,7 +714,11 @@ class File(h5py.File): ax.set_xlim(-x_world / 2, x_world / 2) ax.set_ylim(-y_world / 2, y_world / 2) for fish_id in range(poses.shape[0]): - c = cmap(fish_id) + if c is None: + this_c = cmap(fish_id) + elif isinstance(c, list): + this_c = c[fish_id] + for t in range(0, poses.shape[1] - 1, step_size): if lw_distances: lw = np.mean(line_width[t : t + step_size + 1]) @@ -720,7 +727,7 @@ class File(h5py.File): ax.plot( poses[fish_id, t : t + step_size + 1, 0], poses[fish_id, t : t + step_size + 1, 1], - c=c, + c=this_c, lw=lw, ) # Plotting outside of the figure to have the label