diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 71952787a75e9c93095dccdcd622df711d16d72e..7b83ee16360bc11664bc830fcf2cb0aba6e0c2c9 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -504,12 +504,20 @@ def evaluate_follow_iid( calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2]) ) file.close() - follow.append(path_follow) - iid.append(path_iid) + follow.append(np.array(path_follow)) + iid.append(np.array(path_iid)) grids = [] worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - maxDist = (worldBoundsX ** 2 + worldBoundsY ** 2) ** (1 / 2) + + follow_flat = np.concatenate([f.flatten() for f in follow]) + iid_flat = np.concatenate([i.flatten() for i in iid]) + + # Find the 0.5%/ 99.5% quantiles as min and max for the axis + follow_range = np.max( + [-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)] + ) + max_iid = np.quantile(iid_flat, 0.995) for i in range(len(follow)): follow_iid_data = pd.DataFrame( @@ -524,10 +532,12 @@ def evaluate_follow_iid( x="IID [cm]", y="Follow", data=follow_iid_data, - linewidth=0, - kind="scatter", - xlim=(0, maxDist), - ylim=(-5, 5), + kind="hist", + xlim=(0, max_iid), + ylim=(-follow_range, follow_range), + cbar=True, + legend=True, + joint_kws={"bins": 50}, ) grid.fig.set_figwidth(9) grid.fig.set_figheight(6) diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index cd277138e8d4cd85f5dfdfe0982abb23d577cf5c..d3013e383623bdfd60553544fff6efd75321bc76 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -304,6 +304,8 @@ def validate_positions_range(world_size, positions, e_name): # positions which are just a bit over the world edge are fine error_allowance = 1.01 + positions = np.array([p for p in positions if not any(np.isnan(p))]) + allowed_x = [ -1 * world_size[0] * error_allowance / 2, world_size[0] * error_allowance / 2, @@ -332,12 +334,9 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): - ori_lengths = np.linalg.norm(orientations, axis=1) + orientations = np.array([o for o in orientations if not any(np.isnan(o))]) - # import matplotlib.pyplot as plt - # - # plt.plot(ori_lengths) - # plt.show() + ori_lengths = np.linalg.norm(orientations, axis=1) # Check if all orientation lengths are all 1. Different lengths cause warnings. assert_validate(