From af18330c535e2e296e00b7888e88dd59747ea100 Mon Sep 17 00:00:00 2001
From: Andi Gerken <andi.gerken@gmail.com>
Date: Fri, 16 Jun 2023 15:40:41 +0200
Subject: [PATCH] Fixed social vec evaulation

---
 src/robofish/evaluate/evaluate.py | 90 ++++++++++---------------------
 1 file changed, 29 insertions(+), 61 deletions(-)

diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index ac351f7..326db2f 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -654,56 +654,44 @@ def evaluate_social_vector(
         matplotlib.figure.Figure: The figure of the social vectors.
     """
 
+    try:
+        from fish_models.models.pascals_lstms.attribution import SocialVectors
+    except ImportError:
+        raise ImportError(
+            "Please install the fish_models package to use this function."
+        )
+
     if poses_from_paths is None:
         poses_from_paths, file_settings = utils.get_all_poses_from_paths(
             paths, predicate
         )
 
-    socialVec = []
+    fig, ax = plt.subplots(
+        1, len(poses_from_paths), figsize=(8 * len(poses_from_paths), 8)
+    )
+    if len(poses_from_paths) == 1:
+        ax = [ax]
 
     # Iterate all paths
-    for poses_per_path in poses_from_paths:
-        path_socialVec = []
-
-        # Iterate all files
-        for poses in poses_per_path:
-            # calculate socialVec for every fish combination
-            for i in range(len(poses)):
-                for j in range(len(poses)):
-                    if i != j:
-                        socialVec_input = np.append(
-                            poses[i, :, [0, 1]].T, poses[j, :, [0, 1]].T, axis=1
-                        )
-
-                        path_socialVec.append(
-                            calculate_socialVec(
-                                socialVec_input,
-                                np.arctan2(poses[i, :, 3], poses[i, :, 2]),
-                            )
-                        )
-
-        socialVec.append(np.concatenate(path_socialVec, axis=0))
-
-    grids = []
-
-    for i in range(len(socialVec)):
-        df = pd.DataFrame({"x": socialVec[i][:, 0], "y": socialVec[i][:, 1]})
-        grid = sns.displot(df, x="x", y="y", binwidth=(1, 1), cbar=True)
-        grid.axes[0, 0].set_xlabel("x [cm]")
-        grid.axes[0, 0].set_ylabel("y [cm]")
-        # Limits set by educated guesses. If it doesnt work for your data adjust it.
-        grid.set(xlim=(-10, 10))
-        grid.set(ylim=(-10, 10))
-        grids.append(grid)
-
-    fig = plt.figure(figsize=(8 * len(grids), 8))
+    for i, poses_per_path in enumerate(poses_from_paths):
+        poses = np.stack(poses_per_path)
+        poses = np.concatenate(
+            [
+                poses[..., :2],
+                np.arctan2(poses[..., 3], poses[..., 2])[..., None],
+            ],
+            axis=-1,
+        )
 
-    fig.suptitle("positionVector: from left to right: " + str(labels), fontsize=16)
-    gs = gridspec.GridSpec(1, len(grids))
-
-    for i in range(len(grids)):
-        SeabornFig2Grid(grids[i], fig, gs[i])
+        print(np.stack(poses).shape)
+        social_vec = SocialVectors(poses).social_vectors_without_focal_zeros
+        flat_sv = social_vec.reshape((-1, 3))
 
+        ax[i].hist2d(
+            flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=100
+        )
+        ax[i].set_title(labels[i])
+    plt.suptitle("Social Vectors")
     return fig
 
 
@@ -1284,26 +1272,6 @@ def normalize_series(x: np.ndarray) -> np.ndarray:
     return (x.T / np.linalg.norm(x, axis=-1)).T
 
 
-def calculate_socialVec(data: np.ndarray, angle: np.ndarray) -> np.ndarray:
-    """Calculate social vectors.
-
-    Returns the x, y distance from fish1 to fish2 with respect to the direction of fish1
-    Args:
-        data(np.ndarray): The data to calculate the social vectors from shape (n, 4) (n, (x1, y1, x2, y2))
-        angle(np.ndarray): The angle of fish1 (n, 1)
-    Returns:
-        np.ndarray: The social vectors (n, 2)
-    """
-    # rotate axes by angle and adjust x,y  of points
-    temp = np.copy(data)
-    temp[:, 0] = data[:, 0] * np.cos(angle) - data[:, 1] * np.sin(angle)
-    temp[:, 1] = data[:, 0] * np.sin(angle) + data[:, 1] * np.cos(angle)
-    temp[:, 2] = data[:, 2] * np.cos(angle) - data[:, 3] * np.sin(angle)
-    temp[:, 3] = data[:, 2] * np.sin(angle) + data[:, 3] * np.cos(angle)
-
-    return np.append(temp[:, [2]] - temp[:, [0]], temp[:, [3]] - temp[:, [1]], axis=1)
-
-
 def calculate_distLinePoint(
     x1: float, y1: float, x2: float, y2: float, points: np.ndarray
 ) -> np.ndarray:
-- 
GitLab