Skip to content
Snippets Groups Projects
Commit 3487e21f authored by Andi Gerken's avatar Andi Gerken
Browse files

Added flip elimination

parent 806075f6
Branches
Tags
1 merge request!44Added robofish-io-fix-switches
......@@ -18,6 +18,7 @@ import argparse
import itertools
from pathlib import Path
import robofish.io
import matplotlib.pyplot as plt
try:
from tqdm import tqdm
......@@ -74,7 +75,7 @@ def handle_switches(poses, supress=None):
np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
"""
n_fish = poses.shape[1]
n_timesteps, n_fish, three = poses.shape
all_switches = []
......@@ -126,7 +127,6 @@ def handle_switches(poses, supress=None):
f"Switch: {connections} distance sum\t{np.min(switch_distances_sum):.2f} vs\t{np.min(switch_distances_sum[0]):.2f}"
)
poses[t:] = poses[t:, connections]
all_switches.append(t)
# Update last poses for every fish that is not nan
......@@ -198,10 +198,10 @@ def handle_file(file, args):
pf_np = pf.to_numpy()
print(pf_np.shape)
fname = str(file)[:-4] + ".hdf5" if args.output is None else args.output
io_file_path = str(file)[:-4] + ".hdf5" if args.output is None else args.output
with robofish.io.File(
fname,
io_file_path,
"w",
world_size_cm=[args.world_size, args.world_size],
frequency_hz=args.frequency,
......@@ -209,6 +209,7 @@ def handle_file(file, args):
poses = np.empty((pf_np.shape[0], n_fish, 3), dtype=np.float32)
for f in range(n_fish):
f_cols = pf_np[
:,
header_cols
......@@ -266,6 +267,112 @@ def handle_file(file, args):
# assert (poses[:, 1] <= 101).all(), "Error: y coordinate is not <= 100"
# assert (poses[:, 2] >= 0).all(), "Error: orientation is not >= 0"
# assert (poses[:, 2] <= 2 * np.pi).all(), "Error: orientation is not 2*pi"
return io_file_path, all_switches
def eliminate_flips(io_file_path, analysis_path) -> None:
"""This function eliminates flips in the orientation of the fish.
First we check if we find a pair of two close flips with low speed. If we do we flip the fish between these two flips.
Args:
io_file_path (str): The path to the hdf5 file that should be corrected.
"""
if analysis_path is not None:
analysis_path = Path(analysis_path)
if analysis_path.exists() and analysis_path.is_dir():
analysis_path = analysis_path / "flip_analysis.png"
print("Eliminating flips in ", io_file_path)
flipps = []
flipp_starts = []
with robofish.io.File(io_file_path, "r+") as iof:
n_timesteps = iof.entity_actions_speeds_turns.shape[1]
fig, ax = plt.subplots(1, len(iof.entities), figsize=(20, 7))
for e, entity in enumerate(iof.entities):
actions = entity.actions_speeds_turns
turns = actions[:, 1]
biggest_turns = np.argsort(np.abs(turns[~np.isnan(turns)]))[::-1]
flips = {}
for investigate_turn in biggest_turns:
if (
investigate_turn not in flips.keys()
and investigate_turn not in flips.values()
and not np.isnan(turns[investigate_turn])
and np.abs(turns[investigate_turn]) > 0.6 * np.pi
):
# Find the biggest flip within 10 timesteps from investigate_turn
turns_below = turns[investigate_turn - 20 : investigate_turn]
turns_above = turns[investigate_turn + 1 : investigate_turn + 20]
turns_wo_investigated = np.concatenate(
[turns_below, [0], turns_above]
)
turns_wo_investigated[np.isnan(turns_wo_investigated)] = 0
biggest_neighbors = np.argsort(np.abs(turns_wo_investigated))[::-1]
for neighbor in biggest_neighbors:
if (
neighbor not in flips.keys()
and neighbor not in flips.values()
):
if np.abs(turns_wo_investigated[neighbor]) > 0.6 * np.pi:
flips[investigate_turn] = neighbor + (
investigate_turn - 10
)
break
for k, v in flips.items():
start = min(k, v)
end = max(k, v)
entity["orientations"][start:end] = (
np.pi - entity["orientations"][start:end]
) % (np.pi * 2)
print(f"Flipping from {start} to {end}")
flipps.extend(list(range(start, end)))
flipp_starts.extend([start])
if analysis_path is not None:
all_flip_idx = np.array(list(flips.values()) + list(flips.keys()))
# Transfer the flip ids to the sorted order of abs turns
turn_order = np.argsort(np.abs(turns))
x = np.arange(len(turns), dtype=np.int32)
ax[e].scatter(
x,
np.abs(turns[turn_order]) / np.pi,
c=[
"blue" if i not in all_flip_idx else "red"
for i in x[turn_order]
],
alpha=0.5,
)
if analysis_path is not None:
plt.savefig(analysis_path)
flipps = [i for i in range(n_timesteps) if i in flipps]
iof.attrs["switches"] = np.array(flipps, dtype=np.int32)
iof.update_calculated_data()
return flipp_starts
parser = argparse.ArgumentParser(
......@@ -282,10 +389,19 @@ parser.add_argument("--oricol", default=7)
parser.add_argument("--world_size", default=100)
parser.add_argument("--frequency", default=25)
parser.add_argument("--disable_fix_switches", action="store_true")
parser.add_argument("--disable_fix_flips", action="store_true")
parser.add_argument("--disable_centering", action="store_true")
parser.add_argument("--min_timesteps_between_switches", type=int, default=0)
parser.add_argument(
"--analysis_path",
default=None,
help="Path to save analysis to. Folder will create two png files.",
)
args = parser.parse_args()
if args.analysis_path is not None:
args.analysis_path = Path(args.analysis_path)
for path in args.path:
path = Path(path)
......@@ -294,12 +410,21 @@ for path in args.path:
continue
if path.suffix == ".csv":
handle_file(path, args)
files = [path]
elif path.is_dir():
files = path.rglob("*.csv")
for file in tqdm(files):
handle_file(file, args)
else:
print("'%s' is not a folder nor a csv file" % path)
continue
for file in tqdm(files):
io_file_path, all_switches = handle_file(file, args)
if not args.disable_fix_flips:
all_flipps = eliminate_flips(io_file_path, args.analysis_path)
if args.analysis_path is not None and args.analysis_path.is_dir():
plt.figure()
plt.hist([all_switches, all_flipps], bins=50, label=["switches", "flips"])
plt.legend()
plt.savefig(args.analysis_path / "switches_flipps.png")
......@@ -513,7 +513,7 @@ class File(h5py.File):
), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}"
assert poses.shape[2] in [3, 4]
agents = poses.shape[0]
entity_names = []
entities = []
for i in range(agents):
e_name = None if names is None else names[i]
......@@ -521,7 +521,7 @@ class File(h5py.File):
outlines if outlines is None or outlines.ndim == 3 else outlines[i]
)
individual_id = None if individual_ids is None else individual_ids[i]
entity_names.append(
entities.append(
self.create_entity(
category=category,
sampling=sampling,
......@@ -531,7 +531,7 @@ class File(h5py.File):
outlines=e_outline,
)
)
return entity_names
return entities
def update_calculated_data(self, verbose=False):
changed = any([e.update_calculated_data(verbose) for e in self.entities])
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment