diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 502e0bbb3a06508381770dc28dd571453f1dfa8a..5d1cb7c0b9bc941a1e5d82d852dac5e5bf4b0099 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -304,7 +304,8 @@ def validate_positions_range(world_size, positions, e_name): # positions which are just a bit over the world edge are fine error_allowance = 1.01 - positions = np.array([p for p in positions if not any(np.isnan(p))]) + # Remove rows where there is any nan + positions = np.array(positions)[~np.isnan(positions).any(axis=1)] allowed_x = [ -1 * world_size[0] * error_allowance / 2, @@ -334,7 +335,9 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): - orientations = np.array([o for o in orientations if not any(np.isnan(o))]) + + # Remove rows where there is any nan + orientations = np.array(orientations)[~np.isnan(orientations).any(axis=1)] ori_lengths = np.linalg.norm(orientations, axis=1)