From d2a08d197d628b7225e2d96714ffcaa7399317c2 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Mon, 14 Feb 2022 18:54:09 +0100 Subject: [PATCH] Fixed evaluate scripts to work with nan data. Additional bugfixes. --- src/robofish/evaluate/app.py | 14 +- src/robofish/evaluate/evaluate.py | 150 ++++++++++++------- src/robofish/io/file.py | 24 +-- src/robofish/io/utils.py | 10 +- tests/resources/nan_test.hdf5 | Bin 0 -> 12296 bytes tests/robofish/evaluate/test_app_evaluate.py | 24 +++ tests/robofish/io/test_app_io.py | 2 +- tests/robofish/io/test_file.py | 7 + tests/robofish/io/test_io.py | 4 +- 9 files changed, 160 insertions(+), 75 deletions(-) create mode 100644 tests/resources/nan_test.hdf5 diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 183b5dd..a73fe46 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -26,7 +26,7 @@ def function_dict(): "tank_positions": base.evaluate_tankpositions, "tracks": base.evaluate_tracks, "tracks_distance": base.evaluate_tracks_distance, - "evaluate_positionVec": base.evaluate_positionVec, + "socialVec": base.evaluate_socialVec, "follow_iid": base.evaluate_follow_iid, "individual_speeds": base.evaluate_individual_speeds, "individual_iid": base.evaluate_individual_iid, @@ -114,11 +114,11 @@ def evaluate(args=None): print("\n".join([str(p) for p in save_paths])) else: fig = fdict[args.analysis_type](**params) - - if save_path is None: - plt.show() - else: - fig.savefig(save_path) - plt.close(fig) + if fig is not None: + if save_path is None: + plt.show() + else: + fig.savefig(save_path) + plt.close(fig) else: print(f"Evaluation function not found {args.analysis_type}") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index eaf6acb..00fb7e5 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -7,6 +7,7 @@ # email andi.gerken@gmail.com # Last doku update Feb 2021 +from cv2 import floodFill import robofish.io import robofish.evaluate from pathlib import Path @@ -87,7 +88,10 @@ def evaluate_speed( path_speeds.extend(e_speeds_turns[:, 0] * frequency) file.close() + # Exclude possible nans path_speeds = np.array(path_speeds) + path_speeds = path_speeds[~np.isnan(path_speeds)] + left_quantiles.append(np.quantile(path_speeds, 0.001)) right_quantiles.append(np.quantile(path_speeds, 0.999)) speeds.append(path_speeds) @@ -139,7 +143,10 @@ def evaluate_turn( path_turns.extend(np.rad2deg(e_speeds_turns[:, 1])) file.close() + # Exclude possible nans path_turns = np.array(path_turns) + path_turns = path_turns[~np.isnan(path_turns)] + left_quantiles.append(np.quantile(path_turns, 0.001)) right_quantiles.append(np.quantile(path_turns, 0.999)) turns.append(path_turns) @@ -254,13 +261,13 @@ def evaluate_relative_orientation( poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses + for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): + for j in range(len(poses)): + if i != j: ori_diff = ( - all_poses[j, :, 2] - poses[i, :, 2], - all_poses[j, :, 3] - poses[i, :, 3], + poses[j, :, 2] - poses[i, :, 2], + poses[j, :, 3] - poses[i, :, 3], ) path_orientations.extend(np.arctan2(ori_diff[1], ori_diff[0])) file.close() @@ -367,6 +374,10 @@ def evaluate_tankpositions( predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ + + # Parameter to set how many steps are skipped in the kde plot. + poses_step = 5 + files_per_path = [robofish.io.read_multiple_files(p) for p in paths] x_pos, y_pos = [], [] world_bounds = [] @@ -378,11 +389,18 @@ def evaluate_tankpositions( ) world_bounds.append(file.attrs["world_size_cm"]) for e_poses in poses: - path_x_pos.extend(e_poses[:, 0]) - path_y_pos.extend(e_poses[:, 1]) + path_x_pos.extend(e_poses[::poses_step, 0]) + path_y_pos.extend(e_poses[::poses_step, 1]) file.close() - x_pos.append(np.array(path_x_pos)) - y_pos.append(np.array(path_y_pos)) + + # Exclude possible nans + path_x_pos = np.array(path_x_pos) + path_y_pos = np.array(path_y_pos) + path_x_pos = path_x_pos[~np.isnan(path_x_pos)] + path_y_pos = path_y_pos[~np.isnan(path_y_pos)] + + x_pos.append(path_x_pos) + y_pos.append(path_y_pos) fig, ax = plt.subplots(1, len(x_pos), figsize=(8 * len(x_pos), 8)) if len(x_pos) == 1: @@ -396,12 +414,12 @@ def evaluate_tankpositions( ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") - sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=25, shade=True, ax=ax[i]) + sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=20, shade=True, ax=ax[i]) return fig -def evaluate_positionVec( +def evaluate_socialVec( paths: Iterable[str], labels: Iterable[str] = None, predicate=None, @@ -416,42 +434,55 @@ def evaluate_positionVec( (example: lambda e: e.category == "fish") """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] - posVec = [] + socialVec = [] worldBoundsX, worldBoundsY = [], [] for k, files in enumerate(files_per_path): - path_posVec = [] + path_socialVec = [] for p, file in files.items(): worldBoundsX.append(file.attrs["world_size_cm"][0]) worldBoundsY.append(file.attrs["world_size_cm"][1]) poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses - # calculate posVec for every fish combination + + if poses.shape[0] != 2: + print( + "The positionVec can only be calculated when there are exactly 2 fish." + ) + return + # calculate socialVec for every fish combination for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): - posVec_input = np.append( - poses[i, :, [0, 1]].T, all_poses[j, :, [0, 1]].T, axis=1 + for j in range(len(poses)): + if i != j: + socialVec_input = np.append( + poses[i, :, [0, 1]].T, poses[j, :, [0, 1]].T, axis=1 ) - path_posVec.append( - calculate_posVec( - posVec_input, np.arctan2(poses[i, :, 3], poses[i, :, 2]) + + path_socialVec.append( + calculate_socialVec( + socialVec_input, + np.arctan2(poses[i, :, 3], poses[i, :, 2]), ) ) file.close() - posVec.append(np.concatenate(path_posVec, axis=0)) + + # Concatenate and exclude possible nans + path_socialVec = np.concatenate(path_socialVec, axis=0) + path_socialVec = path_socialVec[~np.isnan(path_socialVec).any(axis=1)] + + socialVec.append(path_socialVec) grids = [] worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - for i in range(len(posVec)): - df = pd.DataFrame({"x": posVec[i][:, 0], "y": posVec[i][:, 1]}) - grid = sns.displot(df, x="x", y="y", binwidth=(10, 10), cbar=True) + for i in range(len(socialVec)): + df = pd.DataFrame({"x": socialVec[i][:, 0], "y": socialVec[i][:, 1]}) + grid = sns.displot(df, x="x", y="y", binwidth=(1, 1), cbar=True) grid.axes[0, 0].set_xlabel("x [cm]") grid.axes[0, 0].set_ylabel("y [cm]") - grid.set(xlim=(-worldBoundsX, worldBoundsX)) - grid.set(ylim=(-worldBoundsY, worldBoundsY)) + # Limits set by educated guesses. If it doesnt work for your data adjust it. + grid.set(xlim=(-worldBoundsX / 7.0, worldBoundsX / 7.0)) + grid.set(ylim=(-worldBoundsY / 7.0, worldBoundsY / 7.0)) grids.append(grid) fig = plt.figure(figsize=(8 * len(grids), 8)) @@ -492,16 +523,22 @@ def evaluate_follow_iid( poses = file.select_entity_poses( None if predicate is None else predicate[k] ) - all_poses = file.entity_poses + + if poses.shape[0] < 2: + print( + "The FollowIID can only be calculated when there are at least 2 fish." + ) + return + for i in range(len(poses)): - for j in range(len(all_poses)): - if (poses[i] != all_poses[j]).any(): + for j in range(len(poses)): + if i != j and (poses[i] != poses[j]).any(): path_iid.append( - calculate_iid(poses[i, :-1, 0:2], all_poses[j, :-1, 0:2]) + calculate_iid(poses[i, :-1, 0:2], poses[j, :-1, 0:2]) ) path_follow.append( - calculate_follow(poses[i, :, 0:2], all_poses[j, :, 0:2]) + calculate_follow(poses[i, :, 0:2], poses[j, :, 0:2]) ) file.close() follow.append(np.array(path_follow)) @@ -513,6 +550,10 @@ def evaluate_follow_iid( follow_flat = np.concatenate([f.flatten() for f in follow]) iid_flat = np.concatenate([i.flatten() for i in iid]) + # Exclude possible nans + follow_flat = follow_flat[~np.isnan(follow_flat)] + iid_flat = iid_flat[~np.isnan(iid_flat)] + # Find the 0.5%/ 99.5% quantiles as min and max for the axis follow_range = np.max( [-1 * np.quantile(follow_flat, 0.005), np.quantile(follow_flat, 0.995)] @@ -671,7 +712,7 @@ def evaluate_individuals( files_per_path = [robofish.io.read_multiple_files(p) for p in paths] fig = plt.figure(figsize=(10, 4)) - small_iid_files = [] + # small_iid_files = [] offset = 0 for k, files in enumerate(files_per_path): all_avg = [] @@ -682,12 +723,18 @@ def evaluate_individuals( metric = file.entity_actions_speeds_turns[..., 0] * file.frequency elif mode == "iid": poses = file.entity_poses_rad + if poses.shape[0] != 2: + print( + "The evaluate_individual_iid function only works when there are exactly 2 individuals." + ) + return metric = calculate_iid(poses[0], poses[1])[None] - mean = np.mean(metric, axis=1)[0] - if mean < 7: - small_iid_files.append(str(file.path)) - all_avg.append(np.mean(metric, axis=1)) - all_std.append(np.std(metric, axis=1)) + # mean = np.mean(metric, axis=1)[0] + # if mean < 7: + # small_iid_files.append(str(file.path)) + + all_avg.append(np.nanmean(metric, axis=1)) + all_std.append(np.nanstd(metric, axis=1)) file.close() all_avg = np.concatenate(all_avg, axis=0) all_std = np.concatenate(all_std, axis=0) @@ -724,15 +771,15 @@ def evaluate_individuals( plt.ylabel(this_text["ylabel"]) plt.tight_layout() - np.random.shuffle(small_iid_files) - - train_part = int(len(small_iid_files) * 0.8) + # This can be used to return a test data and training data split + # np.random.shuffle(small_iid_files) + # train_part = int(len(small_iid_files) * 0.8) - print("TRAINING DATA:") - print(" ".join(small_iid_files[:train_part])) + # print("TRAINING DATA:") + # print(" ".join(small_iid_files[:train_part])) - print("TEST DATA:") - print(" ".join(small_iid_files[train_part:])) + # print("TEST DATA:") + # print(" ".join(small_iid_files[train_part:])) return fig @@ -773,9 +820,10 @@ def evaluate_all( t.refresh() # to show immediately the update save_path = save_folder / (f_name + ".png") fig = f_callable(paths=paths, labels=labels, predicate=predicate) - fig.savefig(save_path) - plt.close(fig) - save_paths.append(save_path) + if fig is not None: + fig.savefig(save_path) + plt.close(fig) + save_paths.append(save_path) return save_paths @@ -825,8 +873,8 @@ def normalize_series(x): return (x.T / np.linalg.norm(x, axis=-1)).T -def calculate_posVec(data, angle): - """Calculate position vectors. +def calculate_socialVec(data, angle): + """Calculate social vectors. Data should be of the form (n, (x1, y1, x2, y2)) and angle of the form (n, 1) returns x, y distance from fish1 to fish2 with respect to the direction diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index cc75d52..d5ba5f9 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -767,7 +767,7 @@ class File(h5py.File): lw=lw, ) # Plotting outside of the figure to have the label - ax.plot([55, 60], [55, 60], lw=5, c=this_c, label=fish_id) + ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) ax.scatter( [poses[:, skip_timesteps, 0]], [poses[:, skip_timesteps, 1]], @@ -930,23 +930,27 @@ class File(h5py.File): this_pose = entity_poses[:, file_frame] if not options["fixed_view"]: - self.middle_of_swarm = options[ - "slow_view" - ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( - this_pose, axis=0 - ) # Find the maximal distance between the entities in x or y direction min_view = np.max( (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] ) + new_view_size = np.max( [options["view_size"], min_view + options["margin"]] ) - self.view_size = ( - options["slow_zoom"] * self.view_size - + (1 - options["slow_zoom"]) * new_view_size - ) + + if not np.isnan(min_view).any() and not new_view_size is np.nan: + self.middle_of_swarm = options[ + "slow_view" + ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( + this_pose, axis=0 + ) + + self.view_size = ( + options["slow_zoom"] * self.view_size + + (1 - options["slow_zoom"]) * new_view_size + ) ax.set_xlim( self.middle_of_swarm[0] - self.view_size / 2, diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 1992dc3..50be83f 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -29,10 +29,12 @@ def limit_angle_range(angle: Union[float, Iterable], _range=(-np.pi, np.pi)): """ assert np.isclose(_range[1] - _range[0], 2 * np.pi) - def limit_one(value): - return (value - _range[0]) % (2 * np.pi) + _range[0] + def limit_simple(a): + return (a - _range[0]) % (2 * np.pi) + _range[0] if isinstance(angle, Iterable): - return np.array([limit_one(v) for v in angle]) + nan = np.isnan(angle) + angle[~nan] = limit_simple(angle[~nan]) else: - return limit_one(angle) + angle = limit_simple(angle) + return angle diff --git a/tests/resources/nan_test.hdf5 b/tests/resources/nan_test.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..3ff017ef600aec8928a7e55c38a68c2821876e22 GIT binary patch literal 12296 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%Lo0fzxZ2+I8r;W02IKpBisx&unDV1h6h8Q2&= zauN_Og8<Zg1!jnV21t^DfgvQw)s=yPkpX5tjE1OUU|@h60Hxr<ql}Re0v@i80U)17 zfCvT#1`Q~E0-DaCT!z%VlFX9K)M6OFI5D>%Co?Y{CIC%t4AA5ZQ-+fkgr-;Ybj`yM zz`())O0j$lA`B7?<@rT9De=XbRjKjGxeN>pf(r6rc?Jdr25}As@p`BYjEoQ$z?==E z85mADfiMq)2Us5?13v=~g9Jlbeo<~>NqkvqQE_H|9s`2_2Lo6hWG=`^T;@S+VPHU* z$IT$ZzyUL_v?vFpn}I=+krC`eP_BpQ;Q&iO*bEHdK!*r{^DzThHCPD>!N>rD@NCKe z3%3GLATcm7Okjg3CzQ``C_&Ugt%Pt!$q^F*==mI!E<v8(08Kw2HVCI>7H7m8g7P1T z$H3qLp~xl~8CbyS8&<!->Zt?r5F;BP6ay~z1+YUEKq(_rg^Vh2`2Z21mHRxP?yG=U z!+^^l3Mvo<Fn>UJ3=Hnh{yyNm4h?89n*o~d85kHcN=gcft@QQNGfVU`a|`s+N_CS` zi*hpa^iopwlQQ#*@{{sQGLwsQ5|i{nY*4)DmlP!?XU9W|E`447qSTz!#NyOqeM3D1 zSb4_^bpR}!pv?hj*~Y*CG6$*zOeOGxsZnAy1V%$(Gz3ONU^E0qLtr!nMnhmU1V%$( zGz3ONU^E0qLx7AB2=?=LVPs%pfQ`3#K*#G~;{fR6cd+ph4(NCXOr3`$#Bu|uID|KH zNJ8V9404dLfQ2)JH%g9_5P+l)2Hf*oBPAe4{fRXM(8rZw<-`hEROPVp7CbM5Dn%9# zQvNf-%l{S7@*8GPhBU<34m5W_L`O+j2*BpM(8G&|!2vv`&B(wDnSU!vEi6sVORkIu zPY>XpCxjWyaKaJ7gbvU{IB*hXFM9opu$KdDZ%S%fVrfnZ!TCP$xI8rRLB{G~H9p)* z6b7VR!#!`mK^3ALY9xd+N{*NiKrh!|^Y#tUd0tpR6yz6YmSpDV!RG1ni!xL5N)q8h zu=0r5d3*G7?E+f4=AZzvIRVWb5YbT*76P#H1wFhF>k2r*<x6s6Nosn2Q6)?tp>+k| zZU(eafXw}2FW+G6hrpAw;Q3TY3Swk}qzDL)fdNuhA=ZV!_{<CvV6C{;W3V!CfaRH( zAO><k7aB1${!oMX#Q|awgu@^Nttml*>*~0Y7#K33GW4M!>6!`ZN(P1>did>y1~J0| z==?8!)gwuE!9fNF5T0?6fdPak9Asbs;f{k03?STakbwb&3l1_cfN;V=1_lrgILN>N z!VU)+7(m$IAOiykD;#8C0AYcH3=AO5aFBrkgdZGWU;yC*2N)PYc)<Y%1`uvIz`y{) z4hI+*K$zhG0|N**>}OyAVNhHgsAphku!mp<2L=WZc5q-|0O1A)1_lrYsafE_z;FPH zA2={DfG`87-QdW;AOOW6bs%iu$iM)?4vq{AAROSxzyQJtjtmSST;RyS0KyH93=AOL z;mE)M!V?@B7(jT2BLf2ngWScypn9{$KH5ORuCSumzRG-_9iwE5ePQ7XyS>u>_7at1 z_Mcfy!5G8_iGk!mYCvj1dO&(XW-u^-+;*To+CX6+h+kzsZy!j$u<*q`ka~$qvHc+Z zpIJ=zgUsA3?Y|#nCZlA^evp}k6~+5OW~$z-*$*-k6oy8IysA(>$SjaN$SjcheNg=% zv%nZ?K3EQ9Ce(hASs?f9M{=JX$V`x3AUi>JgVclE0x}ciHjo=ZZUvbMay!TmAisdj l1o;i*M^NmJ;?WQo4S~@R7!3jXhrqzKD=nZ+*1^@T1OSD7x%~hD literal 0 HcmV?d00001 diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 5a89d39..d8b1d69 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -11,6 +11,30 @@ logging.getLogger().setLevel(logging.INFO) h5py_file_1 = utils.full_path(__file__, "../../resources/valid_1.hdf5") h5py_file_2 = utils.full_path(__file__, "../../resources/valid_2.hdf5") +nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5") + + +def test_app_validate(tmp_path): + """This tests the function of the robofish-io-validate command""" + + class DummyArgs: + def __init__(self, analysis_type, paths, save_path): + self.paths = paths + self.names = None + self.analysis_type = analysis_type + self.save_path = save_path + self.labels = None + + for mode in app.function_dict().keys(): + if mode == "all": + app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path)) + app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path)) + app.evaluate(DummyArgs(mode, [nan_file_path], tmp_path)) + else: + app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png")) + app.evaluate( + DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png") + ) def test_app_validate(tmp_path): diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index 31ca2de..7d5ffd8 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -22,7 +22,7 @@ def test_app_validate(): raw_output = app.validate(DummyArgs(resources_path, "raw")) # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. - assert len(raw_output) == 3 + assert len(raw_output) == 4 app.validate(DummyArgs(resources_path, "human")) diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index aa30a26..04d4b98 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -11,6 +11,7 @@ import logging LOGGER = logging.getLogger(__name__) valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5") +nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5") def test_constructor(): @@ -243,6 +244,12 @@ def test_file_plot(): f.plot(lw_distances=True) +def test_file_plot(): + with robofish.io.File(nan_file_path) as f: + f.plot() + f.plot(lw_distances=True) + + if __name__ == "__main__": # Find all functions in this module and execute them all_functions = inspect.getmembers(sys.modules[__name__], inspect.isfunction) diff --git a/tests/robofish/io/test_io.py b/tests/robofish/io/test_io.py index 0cd1c80..ad0f68f 100644 --- a/tests/robofish/io/test_io.py +++ b/tests/robofish/io/test_io.py @@ -36,8 +36,8 @@ def test_read_multiple_folder(): robofish.io.read_multiple_files(resources_path), robofish.io.read_multiple_files(str(resources_path)), ]: - # Should find the 3 presaved hdf5 files - assert len(sf) == 3 + # Should find the 4 available hdf5 files + assert len(sf) == 4 for p, f in sf.items(): print(p) assert type(f) == robofish.io.File -- GitLab