Skip to content
Snippets Groups Projects
Commit db27f89b authored by Andi Gerken's avatar Andi Gerken
Browse files

Added conversion script from csv and added better support for nan files.

parent dddae355
Branches
Tags
1 merge request!42Added conversion script from csv and added better support for nan files.
Pipeline #53824 passed
...@@ -13,6 +13,7 @@ The script is not finished and still in progress. ...@@ -13,6 +13,7 @@ The script is not finished and still in progress.
# flake8: noqa # flake8: noqa
import pandas as pd import pandas as pd
import numpy as np
import argparse import argparse
from pathlib import Path from pathlib import Path
import robofish.io import robofish.io
...@@ -60,26 +61,244 @@ DEFAULT_COLUMNS = [ ...@@ -60,26 +61,244 @@ DEFAULT_COLUMNS = [
] ]
def handle_file(file): def get_distances(poses, last_poses=None):
pf = pd.read_csv(file) def get_single_distance(pose_1, pose_2):
pf.columns = DEFAULT_COLUMNS assert pose_1.shape == pose_2.shape == (3,)
print(pf) return np.linalg.norm(pose_1[:2] - pose_2[:2])
iof = robofish.io.File(world_size=[100, 100]) if poses.ndim == 2:
n_fish, three = poses.shape
distances = np.zeros((n_fish, n_fish))
for i in range(n_fish):
for j in range(n_fish):
if last_poses is None:
distances[i, j] = get_single_distance(poses[i], poses[j])
else:
distances[i, j] = get_single_distance(poses[i], last_poses[j])
else:
n_frames, n_fish, three = poses.shape
distances = np.zeros((n_frames, n_fish, n_fish))
for t in range(n_frames - 1):
for i in range(n_fish):
for j in range(n_fish):
assert last_poses is None
distances[t, i, j] = get_single_distance(
poses[t + 1, i], poses[t, j]
)
return distances
def handle_file(file, args):
pf = pd.read_csv(file, skiprows=4, header=None, sep=args.sep)
if args.columns_per_entity is None:
all_col_types_matching = []
for cols in range(1, len(pf.columns) // 2 + 1):
dt = np.array(pf.dtypes, dtype=str)
# print(dt)
extraced_col_types = np.array(
[
dt[len(dt) - (cols * (i + 1)) : len(dt) - (cols * i)]
for i in range(len(dt) // cols)
]
)
matching_types = [
all(extraced_col_types[i] == extraced_col_types[0])
for i in range(len(extraced_col_types))
]
if all(matching_types):
all_col_types_matching.append(cols)
# print(cols, "\t", matching_types)
# pf.columns = DEFAULT_COLUMNS
if len(all_col_types_matching) == 0:
print(
"Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity"
)
else:
assert [
all_col_types_matching[i] % all_col_types_matching[0] == 0
for i in range(0, len(all_col_types_matching))
], f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}"
columns_per_entity = all_col_types_matching[0]
print("Found columns_per_entity: %d" % columns_per_entity)
else:
columns_per_entity = args.columns_per_entity
n_fish = len(pf.columns) // columns_per_entity
header_cols = len(pf.columns) % n_fish
print(
f"Header columns: {header_cols}, n_fish: {n_fish}, columns per fish: {columns_per_entity} total columns: {len(pf.columns)}"
)
print()
print(
f"IMPORTANT: Check if this is correct (not automatic):\n\tColumn {args.xcol} is selected as x coordinate\n\tColumn {args.ycol} is selected as y coordinate.\n\tColumn {args.oricol} is selected as orientation.\nIf this is not correct, please specify the correct columns using --xcol, --ycol and --oricol."
)
print()
first_fish = pf.loc[:, header_cols : header_cols + columns_per_entity - 1].head()
first_fish.columns = range(columns_per_entity)
print(first_fish)
pf_np = pf.to_numpy()
print(pf_np.shape)
fname = str(file)[:-4] + ".hdf5" if args.output is None else args.output
with robofish.io.File(
fname,
"w",
world_size_cm=[args.world_size, args.world_size],
frequency_hz=args.frequency,
) as iof:
poses = np.empty((pf_np.shape[0], n_fish, 3), dtype=np.float32)
for f in range(n_fish):
f_cols = pf_np[
:,
header_cols
+ f * columns_per_entity : header_cols
+ (f + 1) * columns_per_entity,
]
poses[:, f] = f_cols[:, [args.xcol, args.ycol, args.oricol]].astype(
np.float32
)
poses[:, :, :2] -= args.world_size / 2 # center world around 0,0
poses[:, :, 1] *= -1 # flip y axis
poses[:, :, 2] *= 2 * np.pi / 360 # convert to radians
poses[:, :, 2] -= np.pi # rotate 180 degrees
def handle_switches(poses, supress=None):
"""Find and handle switches in the data.
Switches are defined as a fish that is not at the same position as it was in the previous frame.
If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame.
If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved.
Args:
poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
supress (list): List of frames to ignore.
Returns:
np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
"""
all_switches = []
last_poses = np.copy(poses[0])
assert not np.any(np.isnan(last_poses)), "Error: NaN in first frame"
for t in range(1, poses.shape[0]):
if not np.any(np.isnan(poses[t])) and (
supress is None or t not in supress
):
distances = get_distances(poses[t], last_poses)
switches = {}
for i in range(n_fish):
if np.argmin(distances[i]) != i:
switches[i] = np.argmin(distances[i])
if len(switches) > 0:
print(f"Switches at time {t}: {switches}")
if sorted(switches.keys()) == sorted(switches.values()):
print("Switches are one-to-one")
for i in switches:
print(list(switches.keys()), list(switches.values()))
connections = np.arange(n_fish, dtype=int)
for k, v in switches.items():
connections[v] = k
else:
print("Attempting to fix switches...")
smallest_distances = np.argsort(distances.flatten())
connections = np.empty(n_fish)
connections[:] = np.nan
print(distances)
print(
smallest_distances[0] // n_fish,
smallest_distances[0] % n_fish,
)
for sd in smallest_distances:
if np.isnan(connections[sd // n_fish]) and not np.any(
connections == sd % n_fish
):
connections[sd // n_fish] = sd % n_fish
if not np.any(np.isnan(connections)):
break
assert np.sum(connections) == np.sum(
range(n_fish)
) # Simple check to see if all fish are connected
connections = connections.astype(int)
print(f"Connections: {connections}")
poses[t:] = poses[t:, connections]
all_switches.append(t)
last_poses = poses[t]
return poses, all_switches
supress = []
if not args.disable_fix_switches:
for run in range(20):
print("RUN ", run)
switched_poses, all_switches = handle_switches(poses)
print(all_switches)
diff = np.diff(all_switches, axis=0)
new_supress = np.array(all_switches)[np.where(diff < 10)]
new_supress = [n for n in new_supress if n not in supress]
if len(new_supress) == 0:
break
print(supress)
supress.extend(new_supress)
type_ = "robot" if pf["Type1"][0] == "R" else "fish" distances = get_distances(switched_poses)
poses = pf[["Robo x", "Robo y", "Robo ori rad"]] switched_poses[
np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
] = np.nan
poses = switched_poses
robot = iof.create_single_entity(type_=type_, name=(str)(pf["ID"][0])) for f in range(n_fish):
iof.create_entity("fish", poses[:, f])
print(iof) # assert np.all(
iof.validate() # poses[np.logical_not(np.isnan(poses[:, 0])), 0] >= 0
# ), f"Error: x coordinate is not positive, {np.min(poses[:, 0])}"
# assert (poses[:, 1] >= -1).all(), "Error: y coordinate is not positive"
# assert (poses[:, 0] <= 101).all(), "Error: x coordinate is not <= 100"
# assert (poses[:, 1] <= 101).all(), "Error: y coordinate is not <= 100"
# assert (poses[:, 2] >= 0).all(), "Error: orientation is not >= 0"
# assert (poses[:, 2] <= 2 * np.pi).all(), "Error: orientation is not 2*pi"
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="This tool converts files from csv files. The column names come currently from 2019/Q_trials." description="This tool converts files from csv files. The column names come currently from 2019/Q_trials."
) )
parser.add_argument("path", nargs="+") parser.add_argument("path", nargs="+")
parser.add_argument("-o", "--output", default=None)
parser.add_argument("--sep", default=";")
parser.add_argument("--header", default=3)
parser.add_argument("--columns_per_entity", default=None)
parser.add_argument("--xcol", default=4)
parser.add_argument("--ycol", default=5)
parser.add_argument("--oricol", default=7)
parser.add_argument("--world_size", default=100)
parser.add_argument("--frequency", default=25)
parser.add_argument("--disable_fix_switches", action="store_true")
args = parser.parse_args() args = parser.parse_args()
for path in args.path: for path in args.path:
...@@ -90,12 +309,12 @@ for path in args.path: ...@@ -90,12 +309,12 @@ for path in args.path:
continue continue
if path.suffix == ".csv": if path.suffix == ".csv":
handle_file(path) handle_file(path, args)
elif path.is_dir(): elif path.is_dir():
files = path.rglob("*.csv") files = path.rglob("*.csv")
for file in tqdm(files): for file in tqdm(files):
handle_file(file) handle_file(file, args)
else: else:
print("'%s' is not a folder nor a csv file" % path) print("'%s' is not a folder nor a csv file" % path)
continue continue
...@@ -344,6 +344,18 @@ def update_individual_ids(args=None): ...@@ -344,6 +344,18 @@ def update_individual_ids(args=None):
n_fish = None n_fish = None
with robofish.io.File(file, "r+") as f: with robofish.io.File(file, "r+") as f:
if "initial_poses_info" in f:
print(f"Found initial_poses_info in {file}.")
for e, entity in enumerate(f.entities):
entity.attrs["individual_id"] = int(
f["initial_poses_info"].attrs["individual_ids"][e]
)
running_individual_id = None
else:
assert (
running_individual_id is not None
), "The update script found files with initial_poses_info and files without. Mixing them is not supported."
if n_fish is None: if n_fish is None:
n_fish = len(f.entities) n_fish = len(f.entities)
else: else:
...@@ -359,6 +371,7 @@ def update_individual_ids(args=None): ...@@ -359,6 +371,7 @@ def update_individual_ids(args=None):
), f"Video in file {file} is not the same as in the previous file." ), f"Video in file {file} is not the same as in the previous file."
for e, entity in enumerate(f.entities): for e, entity in enumerate(f.entities):
entity.attrs["individual_id"] = running_individual_id + e entity.attrs["individual_id"] = running_individual_id + e
# Delete the old individual_id attribute # Delete the old individual_id attribute
......
...@@ -69,16 +69,26 @@ class Entity(h5py.Group): ...@@ -69,16 +69,26 @@ class Entity(h5py.Group):
return entity return entity
@classmethod @classmethod
def convert_rad_to_vector(cla, orientations_rad): def convert_rad_to_vector(cla, orientations_rad: np.ndarray) -> np.ndarray:
if min(orientations_rad) < 0 or max(orientations_rad) > 2 * np.pi: """Converts an orientation array from radiants to a vector.
Args:
orientations_rad (np.ndarray): Orientations in radiants (shape: (n, 1))
Returns:
np.ndarray: Orientations as vectors (shape: (n, 2))
"""
assert orientations_rad.ndim == 2 and orientations_rad.shape[1] == 1
if np.nanmin(orientations_rad) < 0 or np.nanmax(orientations_rad) > 2 * np.pi:
logging.warning( logging.warning(
"Converting orientations, from a bigger range than [0, 2 * pi]. When passing the orientations, they are assumed to be in radians." f"Converting orientations, from a bigger range than [0, 2 * pi]: [{np.nanmin(orientations_rad)}, {np.nanmax(orientations_rad)}]. When passing the orientations, they are assumed to be in radians."
) )
ori_rad = utils.np_array(orientations_rad) ori_rad = utils.np_array(orientations_rad)
assert ori_rad.shape[1] == 1 assert ori_rad.shape[1] == 1
ori_vec = np.empty((ori_rad.shape[0], 2)) ori_vec = np.empty((ori_rad.shape[0], 2))
ori_vec[:, 0] = np.cos(ori_rad[:, 0]) valid_rows = ~np.isnan(ori_rad)[:, 0]
ori_vec[:, 1] = np.sin(ori_rad[:, 0]) ori_vec[valid_rows, 0] = np.cos(ori_rad[valid_rows, 0])
ori_vec[valid_rows, 1] = np.sin(ori_rad[valid_rows, 0])
return ori_vec return ori_vec
@property @property
......
...@@ -1180,7 +1180,14 @@ class File(h5py.File): ...@@ -1180,7 +1180,14 @@ class File(h5py.File):
n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"]) n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])
start_pose = self.entity_poses_rad[:, frame_range[0]] for skip in range(20):
start_pose = self.entity_poses_rad[:, frame_range[0] + skip]
if not np.any(np.isnan(start_pose)):
break
else:
raise ValueError(
"Could not find a valid start pose in the first 20 frames."
)
self.middle_of_swarm = np.mean(start_pose, axis=0) self.middle_of_swarm = np.mean(start_pose, axis=0)
min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
...@@ -1203,18 +1210,26 @@ class File(h5py.File): ...@@ -1203,18 +1210,26 @@ class File(h5py.File):
if not options["fixed_view"]: if not options["fixed_view"]:
# Find the maximal distance between the entities in x or y direction # Find the maximal distance between the entities in x or y direction
min_view = np.max( min_view = np.nanmax(
(np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] (np.nanmax(this_pose, axis=0) - np.nanmin(this_pose, axis=0))[
:2
]
) )
new_view_size = np.max( new_view_size = np.nanmax(
[options["view_size"], min_view + options["margin"]] [options["view_size"], min_view + options["margin"]]
) )
if not np.isnan(min_view).any() and new_view_size is not np.nan: if (
not np.any(np.isnan(min_view))
and not np.any(np.isnan(new_view_size))
and not np.any(np.isnan(this_pose))
):
self.middle_of_swarm = options[ self.middle_of_swarm = options[
"slow_view" "slow_view"
] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( ] * self.middle_of_swarm + (
1 - options["slow_view"]
) * np.nanmean(
this_pose, axis=0 this_pose, axis=0
) )
...@@ -1231,6 +1246,7 @@ class File(h5py.File): ...@@ -1231,6 +1246,7 @@ class File(h5py.File):
self.middle_of_swarm[1] - self.view_size / 2, self.middle_of_swarm[1] - self.view_size / 2,
self.middle_of_swarm[1] + self.view_size / 2, self.middle_of_swarm[1] + self.view_size / 2,
) )
if options["show_text"]: if options["show_text"]:
ax.set_title(title(file_frame)) ax.set_title(title(file_frame))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment