Skip to content
Snippets Groups Projects
Commit 0cfdb639 authored by Andi Gerken's avatar Andi Gerken
Browse files

Fixed follow_iid plots. Fixed validation to allow np.nan

parent d5a55c3c
No related branches found
No related tags found
No related merge requests found
Pipeline #48436 passed
......@@ -504,12 +504,20 @@ def evaluate_follow_iid(
calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2])
)
file.close()
follow.append(path_follow)
iid.append(path_iid)
follow.append(np.array(path_follow))
iid.append(np.array(path_iid))
grids = []
worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY)
maxDist = (worldBoundsX ** 2 + worldBoundsY ** 2) ** (1 / 2)
follow_flat = np.concatenate([f.flatten() for f in follow])
iid_flat = np.concatenate([i.flatten() for i in iid])
# Find the 0.5%/ 99.5% quantiles as min and max for the axis
follow_range = np.max(
[-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)]
)
max_iid = np.quantile(iid_flat, 0.995)
for i in range(len(follow)):
follow_iid_data = pd.DataFrame(
......@@ -524,10 +532,12 @@ def evaluate_follow_iid(
x="IID [cm]",
y="Follow",
data=follow_iid_data,
linewidth=0,
kind="scatter",
xlim=(0, maxDist),
ylim=(-5, 5),
kind="hist",
xlim=(0, max_iid),
ylim=(-follow_range, follow_range),
cbar=True,
legend=True,
joint_kws={"bins": 50},
)
grid.fig.set_figwidth(9)
grid.fig.set_figheight(6)
......
......@@ -304,6 +304,8 @@ def validate_positions_range(world_size, positions, e_name):
# positions which are just a bit over the world edge are fine
error_allowance = 1.01
positions = np.array([p for p in positions if not any(np.isnan(p))])
allowed_x = [
-1 * world_size[0] * error_allowance / 2,
world_size[0] * error_allowance / 2,
......@@ -332,12 +334,9 @@ def validate_positions_range(world_size, positions, e_name):
def validate_orientations_length(orientations, e_name):
ori_lengths = np.linalg.norm(orientations, axis=1)
orientations = np.array([o for o in orientations if not any(np.isnan(o))])
# import matplotlib.pyplot as plt
#
# plt.plot(ori_lengths)
# plt.show()
ori_lengths = np.linalg.norm(orientations, axis=1)
# Check if all orientation lengths are all 1. Different lengths cause warnings.
assert_validate(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment