From af18330c535e2e296e00b7888e88dd59747ea100 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Fri, 16 Jun 2023 15:40:41 +0200 Subject: [PATCH] Fixed social vec evaulation --- src/robofish/evaluate/evaluate.py | 90 ++++++++++--------------------- 1 file changed, 29 insertions(+), 61 deletions(-) diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index ac351f7..326db2f 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -654,56 +654,44 @@ def evaluate_social_vector( matplotlib.figure.Figure: The figure of the social vectors. """ + try: + from fish_models.models.pascals_lstms.attribution import SocialVectors + except ImportError: + raise ImportError( + "Please install the fish_models package to use this function." + ) + if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( paths, predicate ) - socialVec = [] + fig, ax = plt.subplots( + 1, len(poses_from_paths), figsize=(8 * len(poses_from_paths), 8) + ) + if len(poses_from_paths) == 1: + ax = [ax] # Iterate all paths - for poses_per_path in poses_from_paths: - path_socialVec = [] - - # Iterate all files - for poses in poses_per_path: - # calculate socialVec for every fish combination - for i in range(len(poses)): - for j in range(len(poses)): - if i != j: - socialVec_input = np.append( - poses[i, :, [0, 1]].T, poses[j, :, [0, 1]].T, axis=1 - ) - - path_socialVec.append( - calculate_socialVec( - socialVec_input, - np.arctan2(poses[i, :, 3], poses[i, :, 2]), - ) - ) - - socialVec.append(np.concatenate(path_socialVec, axis=0)) - - grids = [] - - for i in range(len(socialVec)): - df = pd.DataFrame({"x": socialVec[i][:, 0], "y": socialVec[i][:, 1]}) - grid = sns.displot(df, x="x", y="y", binwidth=(1, 1), cbar=True) - grid.axes[0, 0].set_xlabel("x [cm]") - grid.axes[0, 0].set_ylabel("y [cm]") - # Limits set by educated guesses. If it doesnt work for your data adjust it. - grid.set(xlim=(-10, 10)) - grid.set(ylim=(-10, 10)) - grids.append(grid) - - fig = plt.figure(figsize=(8 * len(grids), 8)) + for i, poses_per_path in enumerate(poses_from_paths): + poses = np.stack(poses_per_path) + poses = np.concatenate( + [ + poses[..., :2], + np.arctan2(poses[..., 3], poses[..., 2])[..., None], + ], + axis=-1, + ) - fig.suptitle("positionVector: from left to right: " + str(labels), fontsize=16) - gs = gridspec.GridSpec(1, len(grids)) - - for i in range(len(grids)): - SeabornFig2Grid(grids[i], fig, gs[i]) + print(np.stack(poses).shape) + social_vec = SocialVectors(poses).social_vectors_without_focal_zeros + flat_sv = social_vec.reshape((-1, 3)) + ax[i].hist2d( + flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=100 + ) + ax[i].set_title(labels[i]) + plt.suptitle("Social Vectors") return fig @@ -1284,26 +1272,6 @@ def normalize_series(x: np.ndarray) -> np.ndarray: return (x.T / np.linalg.norm(x, axis=-1)).T -def calculate_socialVec(data: np.ndarray, angle: np.ndarray) -> np.ndarray: - """Calculate social vectors. - - Returns the x, y distance from fish1 to fish2 with respect to the direction of fish1 - Args: - data(np.ndarray): The data to calculate the social vectors from shape (n, 4) (n, (x1, y1, x2, y2)) - angle(np.ndarray): The angle of fish1 (n, 1) - Returns: - np.ndarray: The social vectors (n, 2) - """ - # rotate axes by angle and adjust x,y of points - temp = np.copy(data) - temp[:, 0] = data[:, 0] * np.cos(angle) - data[:, 1] * np.sin(angle) - temp[:, 1] = data[:, 0] * np.sin(angle) + data[:, 1] * np.cos(angle) - temp[:, 2] = data[:, 2] * np.cos(angle) - data[:, 3] * np.sin(angle) - temp[:, 3] = data[:, 2] * np.sin(angle) + data[:, 3] * np.cos(angle) - - return np.append(temp[:, [2]] - temp[:, [0]], temp[:, [3]] - temp[:, [1]], axis=1) - - def calculate_distLinePoint( x1: float, y1: float, x2: float, y2: float, points: np.ndarray ) -> np.ndarray: -- GitLab