Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
I
io
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bioroboticslab
robofish
io
Commits
af18330c
Commit
af18330c
authored
2 years ago
by
Andi Gerken
Browse files
Options
Downloads
Patches
Plain Diff
Fixed social vec evaulation
parent
ec4f67d6
No related branches found
No related tags found
1 merge request
!45
Fixed social vec evaulation
Pipeline
#56280
failed
2 years ago
Stage: package
Stage: test
Changes
1
Pipelines
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
src/robofish/evaluate/evaluate.py
+29
-61
29 additions, 61 deletions
src/robofish/evaluate/evaluate.py
with
29 additions
and
61 deletions
src/robofish/evaluate/evaluate.py
+
29
−
61
View file @
af18330c
...
@@ -654,56 +654,44 @@ def evaluate_social_vector(
...
@@ -654,56 +654,44 @@ def evaluate_social_vector(
matplotlib.figure.Figure: The figure of the social vectors.
matplotlib.figure.Figure: The figure of the social vectors.
"""
"""
try
:
from
fish_models.models.pascals_lstms.attribution
import
SocialVectors
except
ImportError
:
raise
ImportError
(
"
Please install the fish_models package to use this function.
"
)
if
poses_from_paths
is
None
:
if
poses_from_paths
is
None
:
poses_from_paths
,
file_settings
=
utils
.
get_all_poses_from_paths
(
poses_from_paths
,
file_settings
=
utils
.
get_all_poses_from_paths
(
paths
,
predicate
paths
,
predicate
)
)
socialVec
=
[]
fig
,
ax
=
plt
.
subplots
(
1
,
len
(
poses_from_paths
),
figsize
=
(
8
*
len
(
poses_from_paths
),
8
)
# Iterate all paths
for
poses_per_path
in
poses_from_paths
:
path_socialVec
=
[]
# Iterate all files
for
poses
in
poses_per_path
:
# calculate socialVec for every fish combination
for
i
in
range
(
len
(
poses
)):
for
j
in
range
(
len
(
poses
)):
if
i
!=
j
:
socialVec_input
=
np
.
append
(
poses
[
i
,
:,
[
0
,
1
]].
T
,
poses
[
j
,
:,
[
0
,
1
]].
T
,
axis
=
1
)
)
if
len
(
poses_from_paths
)
==
1
:
ax
=
[
ax
]
path_socialVec
.
append
(
# Iterate all paths
calculate_socialVec
(
for
i
,
poses_per_path
in
enumerate
(
poses_from_paths
):
socialVec_input
,
poses
=
np
.
stack
(
poses_per_path
)
np
.
arctan2
(
poses
[
i
,
:,
3
],
poses
[
i
,
:,
2
]),
poses
=
np
.
concatenate
(
)
[
poses
[...,
:
2
],
np
.
arctan2
(
poses
[...,
3
],
poses
[...,
2
])[...,
None
],
],
axis
=-
1
,
)
)
socialVec
.
append
(
np
.
concatenate
(
path_socialVec
,
axis
=
0
))
print
(
np
.
stack
(
poses
).
shape
)
social_vec
=
SocialVectors
(
poses
).
social_vectors_without_focal_zeros
grids
=
[]
flat_sv
=
social_vec
.
reshape
((
-
1
,
3
))
for
i
in
range
(
len
(
socialVec
)):
df
=
pd
.
DataFrame
({
"
x
"
:
socialVec
[
i
][:,
0
],
"
y
"
:
socialVec
[
i
][:,
1
]})
grid
=
sns
.
displot
(
df
,
x
=
"
x
"
,
y
=
"
y
"
,
binwidth
=
(
1
,
1
),
cbar
=
True
)
grid
.
axes
[
0
,
0
].
set_xlabel
(
"
x [cm]
"
)
grid
.
axes
[
0
,
0
].
set_ylabel
(
"
y [cm]
"
)
# Limits set by educated guesses. If it doesnt work for your data adjust it.
grid
.
set
(
xlim
=
(
-
10
,
10
))
grid
.
set
(
ylim
=
(
-
10
,
10
))
grids
.
append
(
grid
)
fig
=
plt
.
figure
(
figsize
=
(
8
*
len
(
grids
),
8
))
fig
.
suptitle
(
"
positionVector: from left to right:
"
+
str
(
labels
),
fontsize
=
16
)
gs
=
gridspec
.
GridSpec
(
1
,
len
(
grids
))
for
i
in
range
(
len
(
grids
)):
SeabornFig2Grid
(
grids
[
i
],
fig
,
gs
[
i
])
ax
[
i
].
hist2d
(
flat_sv
[:,
0
],
flat_sv
[:,
1
],
range
=
[[
-
7.5
,
7.5
],
[
-
7.5
,
7.5
]],
bins
=
100
)
ax
[
i
].
set_title
(
labels
[
i
])
plt
.
suptitle
(
"
Social Vectors
"
)
return
fig
return
fig
...
@@ -1284,26 +1272,6 @@ def normalize_series(x: np.ndarray) -> np.ndarray:
...
@@ -1284,26 +1272,6 @@ def normalize_series(x: np.ndarray) -> np.ndarray:
return
(
x
.
T
/
np
.
linalg
.
norm
(
x
,
axis
=-
1
)).
T
return
(
x
.
T
/
np
.
linalg
.
norm
(
x
,
axis
=-
1
)).
T
def
calculate_socialVec
(
data
:
np
.
ndarray
,
angle
:
np
.
ndarray
)
->
np
.
ndarray
:
"""
Calculate social vectors.
Returns the x, y distance from fish1 to fish2 with respect to the direction of fish1
Args:
data(np.ndarray): The data to calculate the social vectors from shape (n, 4) (n, (x1, y1, x2, y2))
angle(np.ndarray): The angle of fish1 (n, 1)
Returns:
np.ndarray: The social vectors (n, 2)
"""
# rotate axes by angle and adjust x,y of points
temp
=
np
.
copy
(
data
)
temp
[:,
0
]
=
data
[:,
0
]
*
np
.
cos
(
angle
)
-
data
[:,
1
]
*
np
.
sin
(
angle
)
temp
[:,
1
]
=
data
[:,
0
]
*
np
.
sin
(
angle
)
+
data
[:,
1
]
*
np
.
cos
(
angle
)
temp
[:,
2
]
=
data
[:,
2
]
*
np
.
cos
(
angle
)
-
data
[:,
3
]
*
np
.
sin
(
angle
)
temp
[:,
3
]
=
data
[:,
2
]
*
np
.
sin
(
angle
)
+
data
[:,
3
]
*
np
.
cos
(
angle
)
return
np
.
append
(
temp
[:,
[
2
]]
-
temp
[:,
[
0
]],
temp
[:,
[
3
]]
-
temp
[:,
[
1
]],
axis
=
1
)
def
calculate_distLinePoint
(
def
calculate_distLinePoint
(
x1
:
float
,
y1
:
float
,
x2
:
float
,
y2
:
float
,
points
:
np
.
ndarray
x1
:
float
,
y1
:
float
,
x2
:
float
,
y2
:
float
,
points
:
np
.
ndarray
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment