diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 183b5ddacd4cae3fd3a321e8e8bf597f38eac920..a73fe46fcd99d3791b02bdb56e9d02ceffaf5076 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -26,7 +26,7 @@ def function_dict(): "tank_positions": base.evaluate_tankpositions, "tracks": base.evaluate_tracks, "tracks_distance": base.evaluate_tracks_distance, - "evaluate_positionVec": base.evaluate_positionVec, + "socialVec": base.evaluate_socialVec, "follow_iid": base.evaluate_follow_iid, "individual_speeds": base.evaluate_individual_speeds, "individual_iid": base.evaluate_individual_iid, @@ -114,11 +114,11 @@ def evaluate(args=None): print("\n".join([str(p) for p in save_paths])) else: fig = fdict[args.analysis_type](**params) - - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close(fig) + if fig is not None: + if save_path is None: + plt.show() + else: + fig.savefig(save_path) + plt.close(fig) else: print(f"Evaluation function not found {args.analysis_type}") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index eaf6acbee55bd6e453e1ea4187864b00c7b17efa..00fb7e5c0ee2460b5e68652da224e916176526c4 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -7,6 +7,7 @@ # email andi.gerken@gmail.com # Last doku update Feb 2021 +from cv2 import floodFill import robofish.io import robofish.evaluate from pathlib import Path @@ -87,7 +88,10 @@ def evaluate_speed( path_speeds.extend(e_speeds_turns[:, 0] * frequency) file.close() + # Exclude possible nans path_speeds = np.array(path_speeds) + path_speeds = path_speeds[~np.isnan(path_speeds)] + left_quantiles.append(np.quantile(path_speeds, 0.001)) right_quantiles.append(np.quantile(path_speeds, 0.999)) speeds.append(path_speeds) @@ -139,7 +143,10 @@ def evaluate_turn( path_turns.extend(np.rad2deg(e_speeds_turns[:, 1])) file.close() + # Exclude possible nans path_turns = np.array(path_turns) + path_turns = path_turns[~np.isnan(path_turns)] + left_quantiles.append(np.quantile(path_turns, 0.001)) right_quantiles.append(np.quantile(path_turns, 0.999)) turns.append(path_turns) @@ -254,13 +261,13 @@ def evaluate_relative_orientation( poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses + for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): + for j in range(len(poses)): + if i != j: ori_diff = ( - all_poses[j, :, 2] - poses[i, :, 2], - all_poses[j, :, 3] - poses[i, :, 3], + poses[j, :, 2] - poses[i, :, 2], + poses[j, :, 3] - poses[i, :, 3], ) path_orientations.extend(np.arctan2(ori_diff[1], ori_diff[0])) file.close() @@ -367,6 +374,10 @@ def evaluate_tankpositions( predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ + + # Parameter to set how many steps are skipped in the kde plot. + poses_step = 5 + files_per_path = [robofish.io.read_multiple_files(p) for p in paths] x_pos, y_pos = [], [] world_bounds = [] @@ -378,11 +389,18 @@ def evaluate_tankpositions( ) world_bounds.append(file.attrs["world_size_cm"]) for e_poses in poses: - path_x_pos.extend(e_poses[:, 0]) - path_y_pos.extend(e_poses[:, 1]) + path_x_pos.extend(e_poses[::poses_step, 0]) + path_y_pos.extend(e_poses[::poses_step, 1]) file.close() - x_pos.append(np.array(path_x_pos)) - y_pos.append(np.array(path_y_pos)) + + # Exclude possible nans + path_x_pos = np.array(path_x_pos) + path_y_pos = np.array(path_y_pos) + path_x_pos = path_x_pos[~np.isnan(path_x_pos)] + path_y_pos = path_y_pos[~np.isnan(path_y_pos)] + + x_pos.append(path_x_pos) + y_pos.append(path_y_pos) fig, ax = plt.subplots(1, len(x_pos), figsize=(8 * len(x_pos), 8)) if len(x_pos) == 1: @@ -396,12 +414,12 @@ def evaluate_tankpositions( ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") - sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=25, shade=True, ax=ax[i]) + sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=20, shade=True, ax=ax[i]) return fig -def evaluate_positionVec( +def evaluate_socialVec( paths: Iterable[str], labels: Iterable[str] = None, predicate=None, @@ -416,42 +434,55 @@ def evaluate_positionVec( (example: lambda e: e.category == "fish") """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] - posVec = [] + socialVec = [] worldBoundsX, worldBoundsY = [], [] for k, files in enumerate(files_per_path): - path_posVec = [] + path_socialVec = [] for p, file in files.items(): worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsY.append(file.attrs["world_size_cm"][1]) poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses - # calculate posVec for every fish combination + + if poses.shape[0] != 2: + print( + "The positionVec can only be calculated when there are exactly 2 fish." + ) + return + # calculate socialVec for every fish combination for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): - posVec_input = np.append( - poses[i, :, [0, 1]].T, all_poses[j, :, [0, 1]].T, axis=1 + for j in range(len(poses)): + if i != j: + socialVec_input = np.append( + poses[i, :, [0, 1]].T, poses[j, :, [0, 1]].T, axis=1 ) - path_posVec.append( - calculate_posVec( - posVec_input, np.arctan2(poses[i, :, 3], poses[i, :, 2]) + + path_socialVec.append( + calculate_socialVec( + socialVec_input, + np.arctan2(poses[i, :, 3], poses[i, :, 2]), ) ) file.close() - posVec.append(np.concatenate(path_posVec, axis=0)) + + # Concatenate and exclude possible nans + path_socialVec = np.concatenate(path_socialVec, axis=0) + path_socialVec = path_socialVec[~np.isnan(path_socialVec).any(axis=1)] + + socialVec.append(path_socialVec) grids = [] worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - for i in range(len(posVec)): - df = pd.DataFrame({"x": posVec[i][:, 0], "y": posVec[i][:, 1]}) - grid = sns.displot(df, x="x", y="y", binwidth=(10, 10), cbar=True) + for i in range(len(socialVec)): + df = pd.DataFrame({"x": socialVec[i][:, 0], "y": socialVec[i][:, 1]}) + grid = sns.displot(df, x="x", y="y", binwidth=(1, 1), cbar=True) grid.axes[0, 0].set_xlabel("x [cm]") grid.axes[0, 0].set_ylabel("y [cm]") - grid.set(xlim=(-worldBoundsX, worldBoundsX)) - grid.set(ylim=(-worldBoundsY, worldBoundsY)) + # Limits set by educated guesses. If it doesnt work for your data adjust it. + grid.set(xlim=(-worldBoundsX / 7.0, worldBoundsX / 7.0)) + grid.set(ylim=(-worldBoundsY / 7.0, worldBoundsY / 7.0)) grids.append(grid) fig = plt.figure(figsize=(8 * len(grids), 8)) @@ -492,16 +523,22 @@ def evaluate_follow_iid( poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses + + if poses.shape[0] < 2: + print( + "The FollowIID can only be calculated when there are at least 2 fish." + ) + return + for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): + for j in range(len(poses)): + if i != j and (poses[i] != poses[j]).any(): path_iid.append( - calculate_iid(poses[i, :-1, 0:2], all_poses[j, :-1, 0:2]) + calculate_iid(poses[i, :-1, 0:2], poses[j, :-1, 0:2]) ) path_follow.append( - calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2]) + calculate_follow(poses[i, :, 0:2], poses[j, :, 0:2]) ) file.close() follow.append(np.array(path_follow)) @@ -513,6 +550,10 @@ def evaluate_follow_iid( follow_flat = np.concatenate([f.flatten() for f in follow]) iid_flat = np.concatenate([i.flatten() for i in iid]) + # Exclude possible nans + follow_flat = follow_flat[~np.isnan(follow_flat)] + iid_flat = iid_flat[~np.isnan(iid_flat)] + # Find the 0.5%/ 99.5% quantiles as min and max for the axis follow_range = np.max( [-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)] @@ -671,7 +712,7 @@ def evaluate_individuals( files_per_path = [robofish.io.read_multiple_files(p) for p in paths] fig = plt.figure(figsize=(10, 4)) - small_iid_files = [] + # small_iid_files = [] offset = 0 for k, files in enumerate(files_per_path): all_avg = [] @@ -682,12 +723,18 @@ def evaluate_individuals( metric = file.entity_actions_speeds_turns[..., 0] * file.frequency elif mode == "iid": poses = file.entity_poses_rad + if poses.shape[0] != 2: + print( + "The evaluate_individual_iid function only works when there are exactly 2 individuals." + ) + return metric = calculate_iid(poses[0], poses[1])[None] - mean = np.mean(metric, axis=1)[0] - if mean < 7: - small_iid_files.append(str(file.path)) - all_avg.append(np.mean(metric, axis=1)) - all_std.append(np.std(metric, axis=1)) + # mean = np.mean(metric, axis=1)[0] + # if mean < 7: + # small_iid_files.append(str(file.path)) + + all_avg.append(np.nanmean(metric, axis=1)) + all_std.append(np.nanstd(metric, axis=1)) file.close() all_avg = np.concatenate(all_avg, axis=0) all_std = np.concatenate(all_std, axis=0) @@ -724,15 +771,15 @@ def evaluate_individuals( plt.ylabel(this_text["ylabel"]) plt.tight_layout() - np.random.shuffle(small_iid_files) - - train_part = int(len(small_iid_files) * 0.8) + # This can be used to return a test data and training data split + # np.random.shuffle(small_iid_files) + # train_part = int(len(small_iid_files) * 0.8) - print("TRAINING DATA:") - print(" ".join(small_iid_files[:train_part])) + # print("TRAINING DATA:") + # print(" ".join(small_iid_files[:train_part])) - print("TEST DATA:") - print(" ".join(small_iid_files[train_part:])) + # print("TEST DATA:") + # print(" ".join(small_iid_files[train_part:])) return fig @@ -773,9 +820,10 @@ def evaluate_all( t.refresh() # to show immediately the update save_path = save_folder / (f_name + ".png") fig = f_callable(paths=paths, labels=labels, predicate=predicate) - fig.savefig(save_path) - plt.close(fig) - save_paths.append(save_path) + if fig is not None: + fig.savefig(save_path) + plt.close(fig) + save_paths.append(save_path) return save_paths @@ -825,8 +873,8 @@ def normalize_series(x): return (x.T / np.linalg.norm(x, axis=-1)).T -def calculate_posVec(data, angle): - """Calculate position vectors. +def calculate_socialVec(data, angle): + """Calculate social vectors. Data should be of the form (n, (x1, y1, x2, y2)) and angle of the form (n, 1) returns x, y distance from fish1 to fish2 with respect to the direction diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index cc75d522c5f364e18854ef9f5bc44ae5a7d33138..d5ba5f909546bff5cca54116ab9491d424454370 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -767,7 +767,7 @@ class File(h5py.File): lw=lw, ) # Plotting outside of the figure to have the label - ax.plot([55, 60], [55, 60], lw=5, c=this_c, label=fish_id) + ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) ax.scatter( [poses[:, skip_timesteps, 0]], [poses[:, skip_timesteps, 1]], @@ -930,23 +930,27 @@ class File(h5py.File): this_pose = entity_poses[:, file_frame] if not options["fixed_view"]: - self.middle_of_swarm = options[ - "slow_view" - ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( - this_pose, axis=0 - ) # Find the maximal distance between the entities in x or y direction min_view = np.max( (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] ) + new_view_size = np.max( [options["view_size"], min_view + options["margin"]] ) - self.view_size = ( - options["slow_zoom"] * self.view_size - + (1 - options["slow_zoom"]) * new_view_size - ) + + if not np.isnan(min_view).any() and not new_view_size is np.nan: + self.middle_of_swarm = options[ + "slow_view" + ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( + this_pose, axis=0 + ) + + self.view_size = ( + options["slow_zoom"] * self.view_size + + (1 - options["slow_zoom"]) * new_view_size + ) ax.set_xlim( self.middle_of_swarm[0] - self.view_size / 2, diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 1992dc36b4a57f47b568e0cf62054eda32338b0b..50be83f8bf4a290aaf23dea88aaf58c65f5816a0 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -29,10 +29,12 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)): """ assert np.isclose(_range[1] - _range[0], 2 * np.pi) - def limit_one(value): - return (value - _range[0]) % (2 * np.pi) + _range[0] + def limit_simple(a): + return (a - _range[0]) % (2 * np.pi) + _range[0] if isinstance(angle, Iterable): - return np.array([limit_one(v) for v in angle]) + nan = np.isnan(angle) + angle[~nan] = limit_simple(angle[~nan]) else: - return limit_one(angle) + angle = limit_simple(angle) + return angle diff --git a/tests/resources/nan_test.hdf5 b/tests/resources/nan_test.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..3ff017ef600aec8928a7e55c38a68c2821876e22 Binary files /dev/null and b/tests/resources/nan_test.hdf5 differ diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 5a89d399dffde8ced6002c35782abbba629fdbae..d8b1d69947a7e6208b085303cb521c95b9a8ec60 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -11,6 +11,30 @@ logging.getLogger().setLevel(logging.INFO) h5py_file_1 = utils.full_path(__file__, "../../resources/valid_1.hdf5") h5py_file_2 = utils.full_path(__file__, "../../resources/valid_2.hdf5") +nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5") + + +def test_app_validate(tmp_path): + """This tests the function of the robofish-io-validate command""" + + class DummyArgs: + def __init__(self, analysis_type, paths, save_path): + self.paths = paths + self.names = None + self.analysis_type = analysis_type + self.save_path = save_path + self.labels = None + + for mode in app.function_dict().keys(): + if mode == "all": + app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path)) + app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path)) + app.evaluate(DummyArgs(mode, [nan_file_path], tmp_path)) + else: + app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png")) + app.evaluate( + DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png") + ) def test_app_validate(tmp_path): diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index 31ca2ded7b17b39049745ff43f8f13f9d185dcac..7d5ffd8e82ab7065f4122c2399080030957fa7d7 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -22,7 +22,7 @@ def test_app_validate(): raw_output = app.validate(DummyArgs(resources_path, "raw")) # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. - assert len(raw_output) == 3 + assert len(raw_output) == 4 app.validate(DummyArgs(resources_path, "human")) diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index aa30a26af3dfc6db0acd4820d653428db572d9e6..04d4b988148fe4c44304deac663e1bb5d6422098 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -11,6 +11,7 @@ import logging LOGGER = logging.getLogger(__name__) valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5") +nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5") def test_constructor(): @@ -243,6 +244,12 @@ def test_file_plot(): f.plot(lw_distances=True) +def test_file_plot(): + with robofish.io.File(nan_file_path) as f: + f.plot() + f.plot(lw_distances=True) + + if __name__ == "__main__": # Find all functions in this module and execute them all_functions = inspect.getmembers(sys.modules[__name__], inspect.isfunction) diff --git a/tests/robofish/io/test_io.py b/tests/robofish/io/test_io.py index 0cd1c80b65b23b53a914acba090ec99f19ae67ca..ad0f68f64a9a35785a065c6b71d8730fb3bfe7c8 100644 --- a/tests/robofish/io/test_io.py +++ b/tests/robofish/io/test_io.py @@ -36,8 +36,8 @@ def test_read_multiple_folder(): robofish.io.read_multiple_files(resources_path), robofish.io.read_multiple_files(str(resources_path)), ]: - # Should find the 3 presaved hdf5 files - assert len(sf) == 3 + # Should find the 4 available hdf5 files + assert len(sf) == 4 for p, f in sf.items(): print(p) assert type(f) == robofish.io.File