diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index cd52952f1431b8744e394a5feb1466057c2881e1..183b5ddacd4cae3fd3a321e8e8bf597f38eac920 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -29,6 +29,7 @@ def function_dict():
         "evaluate_positionVec": base.evaluate_positionVec,
         "follow_iid": base.evaluate_follow_iid,
         "individual_speeds": base.evaluate_individual_speeds,
+        "individual_iid": base.evaluate_individual_iid,
         "all": base.evaluate_all,
     }
 
@@ -55,6 +56,9 @@ def evaluate(args=None):
         formatter_class=argparse.RawTextHelpFormatter,
     )
 
+    for name, func in fdict.items():
+        assert func.__doc__ is not None, f"Function '{name}' does not have a docstring."
+
     parser.add_argument(
         "analysis_type",
         type=str,
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 60c6e7fc1fc22180a1279921a7a3b8ca33cb14bf..e10bbb58df736ab73ef93bdb7b5575dbf4f13f0b 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -613,15 +613,28 @@ def evaluate_tracks(
     return fig
 
 
-def evaluate_individual_speeds(
+def evaluate_individual_speeds(**kwargs):
+    """Evaluate the average speeds per individual"""
+    return evaluate_individuals(mode="speed", **kwargs)
+
+
+def evaluate_individual_iid(**kwargs):
+    """Evaluate the average iid per file"""
+    return evaluate_individuals(mode="iid", **kwargs)
+
+
+def evaluate_individuals(
+    mode: str,
     paths: Iterable[str],
     labels: Iterable[str] = None,
     predicate=None,
     lw_distances=False,
+    threshold=7,
 ):
     """Evaluate the average speeds per individual.
 
     Args:
+        mode: A choice between ['speed', 'iid']
         paths: An array of strings, with files of folders.
         labels: Labels for the paths. If no labels are given, the paths will
             be used
@@ -632,33 +645,68 @@ def evaluate_individual_speeds(
     files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
 
     fig = plt.figure(figsize=(10, 4))
-
+    small_iid_files = []
     offset = 0
     for k, files in enumerate(files_per_path):
-        all_avg_speeds = []
-        all_std_speeds = []
+        all_avg = []
+        all_std = []
+
         for path, file in files.items():
-            speeds = file.entity_actions_speeds_turns[..., 0] * file.frequency
-            all_avg_speeds.append(np.mean(speeds, axis=1))
-            all_std_speeds.append(np.std(speeds, axis=1))
+            if mode == "speed":
+                metric = file.entity_actions_speeds_turns[..., 0] * file.frequency
+            elif mode == "iid":
+                poses = file.entity_poses_rad
+                metric = calculate_iid(poses[0], poses[1])[None]
+                mean = np.mean(metric, axis=1)[0]
+                if mean < 7:
+                    small_iid_files.append(str(file.path))
+            all_avg.append(np.mean(metric, axis=1))
+            all_std.append(np.std(metric, axis=1))
             file.close()
-        all_avg_speeds = np.concatenate(all_avg_speeds, axis=0)
-        all_std_speeds = np.concatenate(all_std_speeds, axis=0)
-        individuals = all_avg_speeds.shape[0]
+        all_avg = np.concatenate(all_avg, axis=0)
+        all_std = np.concatenate(all_std, axis=0)
+        individuals = all_avg.shape[0]
+
         plt.errorbar(
             np.arange(offset, individuals + offset),
-            all_avg_speeds,
-            all_std_speeds,
+            all_avg,
+            all_std,
             label=labels[k],
             fmt="o",
         )
         offset += individuals
-    plt.title("Average speed per individual")
+
+    if mode == "iid":
+        plt.plot([0, offset], [threshold, threshold], c="red", alpha=0.5)
+
+    text = {
+        "speed": {
+            "title": "Average speed per individual",
+            "ylabel": "Average Speed (cm / s) +/- Std Dev",
+            "xlabel": "Individual ID",
+        },
+        "iid": {
+            "title": "Average iid per individual",
+            "ylabel": "Average iid (cm) +/- Std Dev",
+            "xlabel": "File ID",
+        },
+    }
+    this_text = text[mode]
+    plt.title(this_text["title"])
     plt.legend(loc="upper right")
-    plt.xlabel("Individual ID")
-    plt.ylabel("Average Speed (cm / s) +/- Std Dev")
+    plt.xlabel(this_text["xlabel"])
+    plt.ylabel(this_text["ylabel"])
     plt.tight_layout()
 
+    np.random.shuffle(small_iid_files)
+
+    train_part = int(len(small_iid_files) * 0.8)
+
+    print("TRAINING DATA:")
+    print(" ".join(small_iid_files[:train_part]))
+
+    print("TEST DATA:")
+    print(" ".join(small_iid_files[train_part:]))
     return fig
 
 
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 0d8396de57681daa208bffab879faf12f0f17fad..7580ac24b97f3b461bd66fe7816556ecc86b4a8f 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -675,7 +675,7 @@ class File(h5py.File):
     def __str__(self):
         return self.to_string()
 
-    def plot(self, ax=None, lw_distances=False, figsize=None, step_size=4):
+    def plot(self, ax=None, lw_distances=False, figsize=None, step_size=4, c=None):
         poses = self.entity_poses[:, :, :2]
 
         if lw_distances and poses.shape[0] < 2:
@@ -702,6 +702,9 @@ class File(h5py.File):
         cmap = cm.get_cmap("Set1")
 
         x_world, y_world = self.world_size
+        if figsize is None:
+            figsize = (8, 8)
+
         if ax is None:
             fig, ax = plt.subplots(1, 1, figsize=figsize)
 
@@ -711,7 +714,11 @@ class File(h5py.File):
         ax.set_xlim(-x_world / 2, x_world / 2)
         ax.set_ylim(-y_world / 2, y_world / 2)
         for fish_id in range(poses.shape[0]):
-            c = cmap(fish_id)
+            if c is None:
+                this_c = cmap(fish_id)
+            elif isinstance(c, list):
+                this_c = c[fish_id]
+
             for t in range(0, poses.shape[1] - 1, step_size):
                 if lw_distances:
                     lw = np.mean(line_width[t : t + step_size + 1])
@@ -720,7 +727,7 @@ class File(h5py.File):
                 ax.plot(
                     poses[fish_id, t : t + step_size + 1, 0],
                     poses[fish_id, t : t + step_size + 1, 1],
-                    c=c,
+                    c=this_c,
                     lw=lw,
                 )
             # Plotting outside of the figure to have the label