From b1e5e55c1dd8fd0df04b296f1be3585124baba6a Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Mon, 16 May 2022 17:51:35 +0200 Subject: [PATCH] Workaround for issue #24 --- src/robofish/evaluate/evaluate.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 18bb408..7656426 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -6,6 +6,7 @@ # Released under GNU 3.0 License # email andi.gerken@gmail.com +from ast import Import import robofish.io import robofish.evaluate from robofish.io import utils @@ -431,8 +432,17 @@ def evaluate_quiver( max_files=None, bins=25, ): - """Plot the flow of movement in the files.""" - import torch + """BETA: Plot the flow of movement in the files.""" + try: + import torch + from fish_models.models.pascals_lstms.attribution import SocialVectors + except ImportError: + print( + "Either torch or fish_models is not installed.", + "The social vector should come from robofish io.", + "This is a known issue (#24)", + ) + return if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) @@ -480,8 +490,6 @@ def evaluate_quiver( # print(tank_directions[x,y] - tank_directions_speed[x,y]) tank_count[x, y] = len(d) - from fish_models.models.pascals_lstms.attribution import SocialVectors - sv = SocialVectors(torch.tensor(poses_from_paths[0])) sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape( (-1, 3) -- GitLab