diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index c037b583f82cac3fa5fa146b434b13d8fd3c410f..234d5029c4d0429bb290c205c355b88ba36b7f13 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -98,7 +98,7 @@ def evaluate(args=None): args = parser.parse_args() if args.analysis_type == "all" and args.save_path is None: - raise Exception("When the analysis type is all, a path must be given.") + raise Exception("When the analysis type is all, a --save_path must be given.") if args.analysis_type in fdict: if args.labels is None: diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 15117029ab850828e4104f750c58d5d0ee6b48e1..60eb1ca71b926c7c7f625e945c5ed73331f8f9ca 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -60,7 +60,6 @@ def evaluate_speed( # Exclude possible nans path_speeds = np.array(path_speeds) - path_speeds = path_speeds[~np.isnan(path_speeds)] left_quantiles.append(np.quantile(path_speeds, 0.001)) right_quantiles.append(np.quantile(path_speeds, 0.999)) @@ -117,9 +116,7 @@ def evaluate_turn( for e_speeds_turns in speeds_turns: path_turns.extend(np.rad2deg(e_speeds_turns[:, 1])) - # Exclude possible nans path_turns = np.array(path_turns) - path_turns = path_turns[~np.isnan(path_turns)] left_quantiles.append(np.quantile(path_turns, 0.001)) right_quantiles.append(np.quantile(path_turns, 0.999)) @@ -393,8 +390,8 @@ def evaluate_tank_position( new_xy_positions = np.concatenate( [p[:, ::poses_step, :2].reshape(-1, 2) for p in poses_per_path], axis=0 ) - # Exclude possible nans - xy_positions.append(new_xy_positions[~np.isnan(new_xy_positions).any(axis=1)]) + + xy_positions.append(new_xy_positions) fig, ax = plt.subplots(1, len(xy_positions), figsize=(8 * len(xy_positions), 8)) if len(xy_positions) == 1: @@ -466,11 +463,7 @@ def evaluate_social_vector( ) ) - # Concatenate and exclude possible nans - path_socialVec = np.concatenate(path_socialVec, axis=0) - path_socialVec = path_socialVec[~np.isnan(path_socialVec).any(axis=1)] - - socialVec.append(path_socialVec) + socialVec.append(np.concatenate(path_socialVec, axis=0)) grids = [] @@ -535,30 +528,37 @@ def evaluate_follow_iid( calculate_follow(poses[i, :, :2], poses[j, :, :2]) ) - follow.append(np.array(path_follow)) - iid.append(np.array(path_iid)) + follow.append(np.concatenate([np.atleast_1d(f) for f in path_follow])) + iid.append(np.concatenate([np.atleast_1d(i) for i in path_iid])) grids = [] - follow_flat = np.concatenate([f.flatten() for f in follow]) - iid_flat = np.concatenate([i.flatten() for i in iid]) - # Exclude possible nans - follow_flat = follow_flat[~np.isnan(follow_flat)] - iid_flat = iid_flat[~np.isnan(iid_flat)] # Find the 0.5%/ 99.5% quantiles as min and max for the axis follow_range = np.max( - [-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)] + [ + -1 * np.quantile(np.concatenate(follow), 0.1), + np.quantile(np.concatenate(follow), 0.9), + ] ) - max_iid = np.quantile(iid_flat, 0.995) + max_iid = np.quantile(np.concatenate(iid, dtype=np.float32), 0.9) for i in range(len(follow)): + + # Mask data which is outside of the ranges + mask = ( + (follow[i] > -1 * follow_range) + & (follow[i] < follow_range) + & (iid[i] < max_iid) + ) + follow_iid_data = pd.DataFrame( { - "IID [cm]": np.concatenate(iid[i], axis=0), - "Follow": np.concatenate(follow[i], axis=0), - } + "IID [cm]": iid[i][mask], + "Follow": follow[i][mask], + }, + dtype=np.float32, ) plt.rcParams["lines.markersize"] = 1 @@ -569,9 +569,10 @@ def evaluate_follow_iid( kind="hist", xlim=(0, max_iid), ylim=(-follow_range, follow_range), - cbar=True, - legend=True, - joint_kws={"bins": 50}, + # cbar=True, + # legend=True, + joint_kws={"bins": 30}, + marginal_kws=dict(bins=30), ) # grid.fig.set_figwidth(9) # grid.fig.set_figheight(6) @@ -623,7 +624,7 @@ def evaluate_tracks( file_settings=None, lw_distances=False, seed=42, - max_timesteps=4000, + max_timesteps=None, ): """Evaluate the track. diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index bac783619ddc0be738f092105d2f6be36f15b26f..52268948dd45c830d3f98e269c5dd9c805bde027 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -104,6 +104,10 @@ def get_all_data_from_paths( data = file.select_entity_property( pred, entity_property=properties[request_type] ) + + # Exclude timesteps where there is any nan in the row + data = data[:, ~np.isnan(data).any(axis=2).any(axis=0)] + data_from_files.append(data) all_data.append(data_from_files)