Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
I
io
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Requirements
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Test cases
Artifacts
Deploy
Releases
Package registry
Container registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Insights
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
bioroboticslab
robofish
io
Commits
3487e21f
Commit
3487e21f
authored
2 years ago
by
Andi Gerken
Browse files
Options
Downloads
Patches
Plain Diff
Added flip elimination
parent
806075f6
Branches
Branches containing commit
Tags
Tags containing commit
1 merge request
!44
Added robofish-io-fix-switches
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
src/conversion_scripts/convert_from_csv.py
+132
-7
132 additions, 7 deletions
src/conversion_scripts/convert_from_csv.py
src/robofish/io/file.py
+3
-3
3 additions, 3 deletions
src/robofish/io/file.py
with
135 additions
and
10 deletions
src/conversion_scripts/convert_from_csv.py
+
132
−
7
View file @
3487e21f
...
...
@@ -18,6 +18,7 @@ import argparse
import
itertools
from
pathlib
import
Path
import
robofish.io
import
matplotlib.pyplot
as
plt
try
:
from
tqdm
import
tqdm
...
...
@@ -74,7 +75,7 @@ def handle_switches(poses, supress=None):
np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
"""
n_
fish
=
poses
.
shape
[
1
]
n_
timesteps
,
n_fish
,
three
=
poses
.
shape
all_switches
=
[]
...
...
@@ -126,7 +127,6 @@ def handle_switches(poses, supress=None):
f
"
Switch:
{
connections
}
distance sum
\t
{
np
.
min
(
switch_distances_sum
)
:
.
2
f
}
vs
\t
{
np
.
min
(
switch_distances_sum
[
0
])
:
.
2
f
}
"
)
poses
[
t
:]
=
poses
[
t
:,
connections
]
all_switches
.
append
(
t
)
# Update last poses for every fish that is not nan
...
...
@@ -198,10 +198,10 @@ def handle_file(file, args):
pf_np
=
pf
.
to_numpy
()
print
(
pf_np
.
shape
)
fname
=
str
(
file
)[:
-
4
]
+
"
.hdf5
"
if
args
.
output
is
None
else
args
.
output
io_file_path
=
str
(
file
)[:
-
4
]
+
"
.hdf5
"
if
args
.
output
is
None
else
args
.
output
with
robofish
.
io
.
File
(
fname
,
io_file_path
,
"
w
"
,
world_size_cm
=
[
args
.
world_size
,
args
.
world_size
],
frequency_hz
=
args
.
frequency
,
...
...
@@ -209,6 +209,7 @@ def handle_file(file, args):
poses
=
np
.
empty
((
pf_np
.
shape
[
0
],
n_fish
,
3
),
dtype
=
np
.
float32
)
for
f
in
range
(
n_fish
):
f_cols
=
pf_np
[
:,
header_cols
...
...
@@ -266,6 +267,112 @@ def handle_file(file, args):
# assert (poses[:, 1] <= 101).all(), "Error: y coordinate is not <= 100"
# assert (poses[:, 2] >= 0).all(), "Error: orientation is not >= 0"
# assert (poses[:, 2] <= 2 * np.pi).all(), "Error: orientation is not 2*pi"
return
io_file_path
,
all_switches
def
eliminate_flips
(
io_file_path
,
analysis_path
)
->
None
:
"""
This function eliminates flips in the orientation of the fish.
First we check if we find a pair of two close flips with low speed. If we do we flip the fish between these two flips.
Args:
io_file_path (str): The path to the hdf5 file that should be corrected.
"""
if
analysis_path
is
not
None
:
analysis_path
=
Path
(
analysis_path
)
if
analysis_path
.
exists
()
and
analysis_path
.
is_dir
():
analysis_path
=
analysis_path
/
"
flip_analysis.png
"
print
(
"
Eliminating flips in
"
,
io_file_path
)
flipps
=
[]
flipp_starts
=
[]
with
robofish
.
io
.
File
(
io_file_path
,
"
r+
"
)
as
iof
:
n_timesteps
=
iof
.
entity_actions_speeds_turns
.
shape
[
1
]
fig
,
ax
=
plt
.
subplots
(
1
,
len
(
iof
.
entities
),
figsize
=
(
20
,
7
))
for
e
,
entity
in
enumerate
(
iof
.
entities
):
actions
=
entity
.
actions_speeds_turns
turns
=
actions
[:,
1
]
biggest_turns
=
np
.
argsort
(
np
.
abs
(
turns
[
~
np
.
isnan
(
turns
)]))[::
-
1
]
flips
=
{}
for
investigate_turn
in
biggest_turns
:
if
(
investigate_turn
not
in
flips
.
keys
()
and
investigate_turn
not
in
flips
.
values
()
and
not
np
.
isnan
(
turns
[
investigate_turn
])
and
np
.
abs
(
turns
[
investigate_turn
])
>
0.6
*
np
.
pi
):
# Find the biggest flip within 10 timesteps from investigate_turn
turns_below
=
turns
[
investigate_turn
-
20
:
investigate_turn
]
turns_above
=
turns
[
investigate_turn
+
1
:
investigate_turn
+
20
]
turns_wo_investigated
=
np
.
concatenate
(
[
turns_below
,
[
0
],
turns_above
]
)
turns_wo_investigated
[
np
.
isnan
(
turns_wo_investigated
)]
=
0
biggest_neighbors
=
np
.
argsort
(
np
.
abs
(
turns_wo_investigated
))[::
-
1
]
for
neighbor
in
biggest_neighbors
:
if
(
neighbor
not
in
flips
.
keys
()
and
neighbor
not
in
flips
.
values
()
):
if
np
.
abs
(
turns_wo_investigated
[
neighbor
])
>
0.6
*
np
.
pi
:
flips
[
investigate_turn
]
=
neighbor
+
(
investigate_turn
-
10
)
break
for
k
,
v
in
flips
.
items
():
start
=
min
(
k
,
v
)
end
=
max
(
k
,
v
)
entity
[
"
orientations
"
][
start
:
end
]
=
(
np
.
pi
-
entity
[
"
orientations
"
][
start
:
end
]
)
%
(
np
.
pi
*
2
)
print
(
f
"
Flipping from
{
start
}
to
{
end
}
"
)
flipps
.
extend
(
list
(
range
(
start
,
end
)))
flipp_starts
.
extend
([
start
])
if
analysis_path
is
not
None
:
all_flip_idx
=
np
.
array
(
list
(
flips
.
values
())
+
list
(
flips
.
keys
()))
# Transfer the flip ids to the sorted order of abs turns
turn_order
=
np
.
argsort
(
np
.
abs
(
turns
))
x
=
np
.
arange
(
len
(
turns
),
dtype
=
np
.
int32
)
ax
[
e
].
scatter
(
x
,
np
.
abs
(
turns
[
turn_order
])
/
np
.
pi
,
c
=
[
"
blue
"
if
i
not
in
all_flip_idx
else
"
red
"
for
i
in
x
[
turn_order
]
],
alpha
=
0.5
,
)
if
analysis_path
is
not
None
:
plt
.
savefig
(
analysis_path
)
flipps
=
[
i
for
i
in
range
(
n_timesteps
)
if
i
in
flipps
]
iof
.
attrs
[
"
switches
"
]
=
np
.
array
(
flipps
,
dtype
=
np
.
int32
)
iof
.
update_calculated_data
()
return
flipp_starts
parser
=
argparse
.
ArgumentParser
(
...
...
@@ -282,10 +389,19 @@ parser.add_argument("--oricol", default=7)
parser
.
add_argument
(
"
--world_size
"
,
default
=
100
)
parser
.
add_argument
(
"
--frequency
"
,
default
=
25
)
parser
.
add_argument
(
"
--disable_fix_switches
"
,
action
=
"
store_true
"
)
parser
.
add_argument
(
"
--disable_fix_flips
"
,
action
=
"
store_true
"
)
parser
.
add_argument
(
"
--disable_centering
"
,
action
=
"
store_true
"
)
parser
.
add_argument
(
"
--min_timesteps_between_switches
"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"
--analysis_path
"
,
default
=
None
,
help
=
"
Path to save analysis to. Folder will create two png files.
"
,
)
args
=
parser
.
parse_args
()
if
args
.
analysis_path
is
not
None
:
args
.
analysis_path
=
Path
(
args
.
analysis_path
)
for
path
in
args
.
path
:
path
=
Path
(
path
)
...
...
@@ -294,12 +410,21 @@ for path in args.path:
continue
if
path
.
suffix
==
"
.csv
"
:
handle_file
(
path
,
args
)
files
=
[
path
]
elif
path
.
is_dir
():
files
=
path
.
rglob
(
"
*.csv
"
)
for
file
in
tqdm
(
files
):
handle_file
(
file
,
args
)
else
:
print
(
"'
%s
'
is not a folder nor a csv file
"
%
path
)
continue
for
file
in
tqdm
(
files
):
io_file_path
,
all_switches
=
handle_file
(
file
,
args
)
if
not
args
.
disable_fix_flips
:
all_flipps
=
eliminate_flips
(
io_file_path
,
args
.
analysis_path
)
if
args
.
analysis_path
is
not
None
and
args
.
analysis_path
.
is_dir
():
plt
.
figure
()
plt
.
hist
([
all_switches
,
all_flipps
],
bins
=
50
,
label
=
[
"
switches
"
,
"
flips
"
])
plt
.
legend
()
plt
.
savefig
(
args
.
analysis_path
/
"
switches_flipps.png
"
)
This diff is collapsed.
Click to expand it.
src/robofish/io/file.py
+
3
−
3
View file @
3487e21f
...
...
@@ -513,7 +513,7 @@ class File(h5py.File):
),
f
"
A 3 dimensional array was expected (entity, timestep, 3). There were
{
poses
.
ndim
}
dimensions in poses:
{
poses
.
shape
}
"
assert
poses
.
shape
[
2
]
in
[
3
,
4
]
agents
=
poses
.
shape
[
0
]
entit
y_nam
es
=
[]
entit
i
es
=
[]
for
i
in
range
(
agents
):
e_name
=
None
if
names
is
None
else
names
[
i
]
...
...
@@ -521,7 +521,7 @@ class File(h5py.File):
outlines
if
outlines
is
None
or
outlines
.
ndim
==
3
else
outlines
[
i
]
)
individual_id
=
None
if
individual_ids
is
None
else
individual_ids
[
i
]
entit
y_nam
es
.
append
(
entit
i
es
.
append
(
self
.
create_entity
(
category
=
category
,
sampling
=
sampling
,
...
...
@@ -531,7 +531,7 @@ class File(h5py.File):
outlines
=
e_outline
,
)
)
return
entit
y_nam
es
return
entit
i
es
def
update_calculated_data
(
self
,
verbose
=
False
):
changed
=
any
([
e
.
update_calculated_data
(
verbose
)
for
e
in
self
.
entities
])
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment