From 0cfdb6392bde05af789dca19eae6ad1846a776d2 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Mon, 14 Feb 2022 15:15:59 +0100 Subject: [PATCH] Fixed follow_iid plots. Fixed validation to allow np.nan --- src/robofish/evaluate/evaluate.py | 24 +++++++++++++++++------- src/robofish/io/validation.py | 9 ++++----- 2 files changed, 21 insertions(+), 12 deletions(-) diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 7195278..7b83ee1 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -504,12 +504,20 @@ def evaluate_follow_iid( calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2]) ) file.close() - follow.append(path_follow) - iid.append(path_iid) + follow.append(np.array(path_follow)) + iid.append(np.array(path_iid)) grids = [] worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - maxDist = (worldBoundsX ** 2 + worldBoundsY ** 2) ** (1 / 2) + + follow_flat = np.concatenate([f.flatten() for f in follow]) + iid_flat = np.concatenate([i.flatten() for i in iid]) + + # Find the 0.5%/ 99.5% quantiles as min and max for the axis + follow_range = np.max( + [-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)] + ) + max_iid = np.quantile(iid_flat, 0.995) for i in range(len(follow)): follow_iid_data = pd.DataFrame( @@ -524,10 +532,12 @@ def evaluate_follow_iid( x="IID [cm]", y="Follow", data=follow_iid_data, - linewidth=0, - kind="scatter", - xlim=(0, maxDist), - ylim=(-5, 5), + kind="hist", + xlim=(0, max_iid), + ylim=(-follow_range, follow_range), + cbar=True, + legend=True, + joint_kws={"bins": 50}, ) grid.fig.set_figwidth(9) grid.fig.set_figheight(6) diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index cd27713..d3013e3 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -304,6 +304,8 @@ def validate_positions_range(world_size, positions, e_name): # positions which are just a bit over the world edge are fine error_allowance = 1.01 + positions = np.array([p for p in positions if not any(np.isnan(p))]) + allowed_x = [ -1 * world_size[0] * error_allowance / 2, world_size[0] * error_allowance / 2, @@ -332,12 +334,9 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): - ori_lengths = np.linalg.norm(orientations, axis=1) + orientations = np.array([o for o in orientations if not any(np.isnan(o))]) - # import matplotlib.pyplot as plt - # - # plt.plot(ori_lengths) - # plt.show() + ori_lengths = np.linalg.norm(orientations, axis=1) # Check if all orientation lengths are all 1. Different lengths cause warnings. assert_validate( -- GitLab