diff --git a/src/conversion_scripts/convert_from_csv.py b/src/conversion_scripts/convert_from_csv.py
index f9e6405af1d528ebe7007571adea58c67958fda0..e437cd82ed603309497b9797cdb5e1b85ec50a20 100644
--- a/src/conversion_scripts/convert_from_csv.py
+++ b/src/conversion_scripts/convert_from_csv.py
@@ -13,6 +13,7 @@ The script is not finished and still in progress.
 # flake8: noqa
 
 import pandas as pd
+import numpy as np
 import argparse
 from pathlib import Path
 import robofish.io
@@ -60,26 +61,244 @@ DEFAULT_COLUMNS = [
 ]
 
 
-def handle_file(file):
-    pf = pd.read_csv(file)
-    pf.columns = DEFAULT_COLUMNS
-    print(pf)
+def get_distances(poses, last_poses=None):
+    def get_single_distance(pose_1, pose_2):
+        assert pose_1.shape == pose_2.shape == (3,)
+        return np.linalg.norm(pose_1[:2] - pose_2[:2])
 
-    iof = robofish.io.File(world_size=[100, 100])
+    if poses.ndim == 2:
+        n_fish, three = poses.shape
+        distances = np.zeros((n_fish, n_fish))
+        for i in range(n_fish):
+            for j in range(n_fish):
+                if last_poses is None:
+                    distances[i, j] = get_single_distance(poses[i], poses[j])
+                else:
+                    distances[i, j] = get_single_distance(poses[i], last_poses[j])
+    else:
+        n_frames, n_fish, three = poses.shape
+        distances = np.zeros((n_frames, n_fish, n_fish))
+
+        for t in range(n_frames - 1):
+            for i in range(n_fish):
+                for j in range(n_fish):
+                    assert last_poses is None
+                    distances[t, i, j] = get_single_distance(
+                        poses[t + 1, i], poses[t, j]
+                    )
+    return distances
+
+
+def handle_file(file, args):
+    pf = pd.read_csv(file, skiprows=4, header=None, sep=args.sep)
+    if args.columns_per_entity is None:
+        all_col_types_matching = []
+        for cols in range(1, len(pf.columns) // 2 + 1):
+
+            dt = np.array(pf.dtypes, dtype=str)
+            # print(dt)
+
+            extraced_col_types = np.array(
+                [
+                    dt[len(dt) - (cols * (i + 1)) : len(dt) - (cols * i)]
+                    for i in range(len(dt) // cols)
+                ]
+            )
+
+            matching_types = [
+                all(extraced_col_types[i] == extraced_col_types[0])
+                for i in range(len(extraced_col_types))
+            ]
+
+            if all(matching_types):
+                all_col_types_matching.append(cols)
+                # print(cols, "\t", matching_types)
+
+        # pf.columns = DEFAULT_COLUMNS
+        if len(all_col_types_matching) == 0:
+            print(
+                "Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity"
+            )
+        else:
+            assert all(
+                all_col_types_matching[i] % all_col_types_matching[0] == 0
+                for i in range(0, len(all_col_types_matching))
+            ), f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}"
+            columns_per_entity = all_col_types_matching[0]
+            print("Found columns_per_entity: %d" % columns_per_entity)
+
+    else:
+        columns_per_entity = args.columns_per_entity
+
+    n_fish = len(pf.columns) // columns_per_entity
+    header_cols = len(pf.columns) % n_fish
+
+    print(
+        f"Header columns: {header_cols}, n_fish: {n_fish}, columns per fish: {columns_per_entity} total columns: {len(pf.columns)}"
+    )
+    print()
+
+    print(
+        f"IMPORTANT: Check if this is correct (not automatic):\n\tColumn {args.xcol} is selected as x coordinate\n\tColumn {args.ycol} is selected as y coordinate.\n\tColumn {args.oricol} is selected as orientation.\nIf this is not correct, please specify the correct columns using --xcol, --ycol and --oricol."
+    )
+    print()
+
+    first_fish = pf.loc[:, header_cols : header_cols + columns_per_entity - 1].head()
+    first_fish.columns = range(columns_per_entity)
+    print(first_fish)
+
+    pf_np = pf.to_numpy()
+    print(pf_np.shape)
+
+    fname = str(file)[:-4] + ".hdf5" if args.output is None else args.output
+
+    with robofish.io.File(
+        fname,
+        "w",
+        world_size_cm=[args.world_size, args.world_size],
+        frequency_hz=args.frequency,
+    ) as iof:
+        poses = np.empty((pf_np.shape[0], n_fish, 3), dtype=np.float32)
+
+        for f in range(n_fish):
+            f_cols = pf_np[
+                :,
+                header_cols
+                + f * columns_per_entity : header_cols
+                + (f + 1) * columns_per_entity,
+            ]
+            poses[:, f] = f_cols[:, [args.xcol, args.ycol, args.oricol]].astype(
+                np.float32
+            )
+
+        poses[:, :, :2] -= args.world_size / 2  # center world around 0,0
+        poses[:, :, 1] *= -1  # flip y axis
+        poses[:, :, 2] *= 2 * np.pi / 360  # convert to radians
+        poses[:, :, 2] -= np.pi  # rotate 180 degrees
+
+        def handle_switches(poses, supress=None):
+            """Find and correct identity switches in the data.
+            A switch is detected when a fish is closer to the previous-frame position of another fish than to its own previous position.
+            If the detected switches form a one-to-one mapping, the identities are swapped for all following frames.
+            Otherwise the assignment is rebuilt greedily, connecting the smallest distances between current and previous positions first.
+
+            Args:
+                poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
+                supress (list): List of frames to ignore.
+            Returns:
+                tuple: The poses of shape (n_frames, n_fish, 3) with switches handled and a list of frames at which switches were corrected.
+            """
+
+            all_switches = []
+            last_poses = np.copy(poses[0])
+            assert not np.any(np.isnan(last_poses)), "Error: NaN in first frame"
+            for t in range(1, poses.shape[0]):
+                if not np.any(np.isnan(poses[t])) and (
+                    supress is None or t not in supress
+                ):
+
+                    distances = get_distances(poses[t], last_poses)
+
+                    switches = {}
+                    for i in range(n_fish):
+                        if np.argmin(distances[i]) != i:
+                            switches[i] = np.argmin(distances[i])
+
+                    if len(switches) > 0:
+                        print(f"Switches at time {t}: {switches}")
+
+                        if sorted(switches.keys()) == sorted(switches.values()):
+                            print("Switches are one-to-one")
+
+                            for i in switches:
+                                print(list(switches.keys()), list(switches.values()))
+                            connections = np.arange(n_fish, dtype=int)
+                            for k, v in switches.items():
+                                connections[v] = k
+                        else:
+                            print("Attempting to fix switches...")
+
+                            smallest_distances = np.argsort(distances.flatten())
+
+                            connections = np.empty(n_fish)
+                            connections[:] = np.nan
+
+                            print(distances)
+                            print(
+                                smallest_distances[0] // n_fish,
+                                smallest_distances[0] % n_fish,
+                            )
+                            for sd in smallest_distances:
+                                if np.isnan(connections[sd // n_fish]) and not np.any(
+                                    connections == sd % n_fish
+                                ):
+                                    connections[sd // n_fish] = sd % n_fish
+
+                                if not np.any(np.isnan(connections)):
+                                    break
+                            assert np.sum(connections) == np.sum(
+                                range(n_fish)
+                            )  # Simple check to see if all fish are connected
+
+                        connections = connections.astype(int)
+                        print(f"Connections: {connections}")
+                        poses[t:] = poses[t:, connections]
+
+                        all_switches.append(t)
+
+                    last_poses = poses[t]
+            return poses, all_switches
+
+        supress = []
+
+        if not args.disable_fix_switches:
+            for run in range(20):
+                print("RUN ", run)
+                switched_poses, all_switches = handle_switches(poses, supress)
+
+                print(all_switches)
+                diff = np.diff(all_switches, axis=0)
+
+                new_supress = np.array(all_switches)[np.where(diff < 10)]
+                new_supress = [n for n in new_supress if n not in supress]
+                if len(new_supress) == 0:
+                    break
+                print(supress)
+                supress.extend(new_supress)
 
-    type_ = "robot" if pf["Type1"][0] == "R" else "fish"
-    poses = pf[["Robo x", "Robo y", "Robo ori rad"]]
+            distances = get_distances(switched_poses)
+            switched_poses[
+                np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
+            ] = np.nan
+            poses = switched_poses
 
-    robot = iof.create_single_entity(type_=type_, name=(str)(pf["ID"][0]))
+        for f in range(n_fish):
+            iof.create_entity("fish", poses[:, f])
 
-    print(iof)
-    iof.validate()
+        # assert np.all(
+        #     poses[np.logical_not(np.isnan(poses[:, 0])), 0] >= 0
+        # ), f"Error: x coordinate is not positive, {np.min(poses[:, 0])}"
+        # assert (poses[:, 1] >= -1).all(), "Error: y coordinate is not positive"
+        # assert (poses[:, 0] <= 101).all(), "Error: x coordinate is not <= 100"
+        # assert (poses[:, 1] <= 101).all(), "Error: y coordinate is not <= 100"
+        # assert (poses[:, 2] >= 0).all(), "Error: orientation is not >= 0"
+        # assert (poses[:, 2] <= 2 * np.pi).all(), "Error: orientation is not 2*pi"
 
 
 parser = argparse.ArgumentParser(
     description="This tool converts files from csv files. The column names come currently from 2019/Q_trials."
 )
 parser.add_argument("path", nargs="+")
+parser.add_argument("-o", "--output", default=None)
+parser.add_argument("--sep", default=";")
+parser.add_argument("--header", default=3)
+parser.add_argument("--columns_per_entity", default=None)
+parser.add_argument("--xcol", default=4)
+parser.add_argument("--ycol", default=5)
+parser.add_argument("--oricol", default=7)
+parser.add_argument("--world_size", default=100)
+parser.add_argument("--frequency", default=25)
+parser.add_argument("--disable_fix_switches", action="store_true")
 args = parser.parse_args()
 
 for path in args.path:
@@ -90,12 +309,12 @@ for path in args.path:
         continue
 
     if path.suffix == ".csv":
-        handle_file(path)
+        handle_file(path, args)
 
     elif path.is_dir():
         files = path.rglob("*.csv")
         for file in tqdm(files):
-            handle_file(file)
+            handle_file(file, args)
     else:
         print("'%s' is not a folder nor a csv file" % path)
         continue
diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index f89674b513778f3c5e3d85b079277474e7672bba..16f714d8553fd60950e475f4f14dc5ec9200fcd0 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -344,28 +344,41 @@ def update_individual_ids(args=None):
     n_fish = None
 
     with robofish.io.File(file, "r+") as f:
-        if n_fish is None:
-            n_fish = len(f.entities)
+        if "initial_poses_info" in f:
+            print(f"Found initial_poses_info in {file}.")
+            for e, entity in enumerate(f.entities):
+                entity.attrs["individual_id"] = int(
+                    f["initial_poses_info"].attrs["individual_ids"][e]
+                )
+
+            running_individual_id = None
         else:
-            assert n_fish == len(
-                f.entities
-            ), f"Number of fish in file {file} is not the same as in the previous file."
-        if "video" in f.attrs:
-            if video is None and "video" in f.attrs:
-                video = f.attrs["video"]
+            assert (
+                running_individual_id is not None
+            ), "The update script found files with initial_poses_info and files without. Mixing them is not supported."
+            if n_fish is None:
+                n_fish = len(f.entities)
             else:
-                assert (
-                    video == f.attrs["video"]
-                ), f"Video in file {file} is not the same as in the previous file."
-
-        for e, entity in enumerate(f.entities):
-            entity.attrs["individual_id"] = running_individual_id + e
-
-            # Delete the old individual_id attribute
-            if "global_individual_id" in entity.attrs:
-                del entity.attrs["global_individual_id"]
-
-        running_individual_id += n_fish
+                assert n_fish == len(
+                    f.entities
+                ), f"Number of fish in file {file} is not the same as in the previous file."
+            if "video" in f.attrs:
+                if video is None and "video" in f.attrs:
+                    video = f.attrs["video"]
+                else:
+                    assert (
+                        video == f.attrs["video"]
+                    ), f"Video in file {file} is not the same as in the previous file."
+
+            for e, entity in enumerate(f.entities):
+
+                entity.attrs["individual_id"] = running_individual_id + e
+
+                # Delete the old individual_id attribute
+                if "global_individual_id" in entity.attrs:
+                    del entity.attrs["global_individual_id"]
+
+            running_individual_id += n_fish
     print("Update finished.")
 
     for fp in sorted(files):
diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py
index 3d1649cd7604053dac8e74b0a673883403a056db..5c54a4c491c942d763cf96a09a7a1bf558f38d40 100644
--- a/src/robofish/io/entity.py
+++ b/src/robofish/io/entity.py
@@ -69,16 +69,26 @@ class Entity(h5py.Group):
         return entity
 
     @classmethod
-    def convert_rad_to_vector(cla, orientations_rad):
-        if min(orientations_rad) < 0 or max(orientations_rad) > 2 * np.pi:
+    def convert_rad_to_vector(cla, orientations_rad: np.ndarray) -> np.ndarray:
+        """Converts an orientation array from radians to a vector.
+
+        Args:
+            orientations_rad (np.ndarray): Orientations in radians (shape: (n, 1))
+
+        Returns:
+            np.ndarray: Orientations as vectors (shape: (n, 2))
+        """
+        assert orientations_rad.ndim == 2 and orientations_rad.shape[1] == 1
+        if np.nanmin(orientations_rad) < 0 or np.nanmax(orientations_rad) > 2 * np.pi:
             logging.warning(
-                "Converting orientations, from a bigger range than [0, 2 * pi]. When passing the orientations, they are assumed to be in radians."
+                f"Converting orientations, from a bigger range than [0, 2 * pi]: [{np.nanmin(orientations_rad)}, {np.nanmax(orientations_rad)}]. When passing the orientations, they are assumed to be in radians."
             )
         ori_rad = utils.np_array(orientations_rad)
         assert ori_rad.shape[1] == 1
         ori_vec = np.empty((ori_rad.shape[0], 2))
-        ori_vec[:, 0] = np.cos(ori_rad[:, 0])
-        ori_vec[:, 1] = np.sin(ori_rad[:, 0])
+        valid_rows = ~np.isnan(ori_rad)[:, 0]
+        ori_vec[valid_rows, 0] = np.cos(ori_rad[valid_rows, 0])
+        ori_vec[valid_rows, 1] = np.sin(ori_rad[valid_rows, 0])
         return ori_vec
 
     @property
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index cacbcb6a1e156b12d908a2faaf208e7059453cca..b3f7913192f1e718e8373dc9fc8897af8736dfda 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -1180,7 +1180,14 @@ class File(h5py.File):
 
         n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])
 
-        start_pose = self.entity_poses_rad[:, frame_range[0]]
+        for skip in range(20):
+            start_pose = self.entity_poses_rad[:, frame_range[0] + skip]
+            if not np.any(np.isnan(start_pose)):
+                break
+        else:
+            raise ValueError(
+                "Could not find a valid start pose in the first 20 frames."
+            )
 
         self.middle_of_swarm = np.mean(start_pose, axis=0)
         min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
@@ -1203,18 +1210,26 @@ class File(h5py.File):
 
             if not options["fixed_view"]:
                 # Find the maximal distance between the entities in x or y direction
-                min_view = np.max(
-                    (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
+                min_view = np.nanmax(
+                    (np.nanmax(this_pose, axis=0) - np.nanmin(this_pose, axis=0))[
+                        :2
+                    ]
                 )
 
-                new_view_size = np.max(
+                new_view_size = np.nanmax(
                     [options["view_size"], min_view + options["margin"]]
                 )
 
-                if not np.isnan(min_view).any() and new_view_size is not np.nan:
+                if (
+                    not np.any(np.isnan(min_view))
+                    and not np.any(np.isnan(new_view_size))
+                    and not np.any(np.isnan(this_pose))
+                ):
                     self.middle_of_swarm = options[
                         "slow_view"
-                    ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
+                    ] * self.middle_of_swarm + (
+                        1 - options["slow_view"]
+                    ) * np.nanmean(
                         this_pose, axis=0
                     )
 
@@ -1231,6 +1246,7 @@ class File(h5py.File):
                 self.middle_of_swarm[1] - self.view_size / 2,
                 self.middle_of_swarm[1] + self.view_size / 2,
             )
+
             if options["show_text"]:
                 ax.set_title(title(file_frame))
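
Note (not part of the patch): the identity-switch handling added to convert_from_csv.py hinges on one check: for every frame, compare each fish's position with all positions from the previous frame and flag a switch whenever the nearest previous position belongs to a different fish. The standalone sketch below restates that check with vectorized NumPy; the function name detect_switches and the example values are illustrative only and do not appear in the repository.

import numpy as np


def detect_switches(prev_poses, cur_poses):
    """Return {current_index: nearest_previous_index} for every fish whose nearest
    previous-frame position belongs to a different fish (a candidate ID switch).

    Both arrays have shape (n_fish, 3): x, y, orientation; only x/y are compared,
    mirroring get_single_distance in the patch.
    """
    # distances[i, j] = distance of current fish i to previous fish j
    diff = cur_poses[:, None, :2] - prev_poses[None, :, :2]
    distances = np.linalg.norm(diff, axis=-1)
    nearest = np.argmin(distances, axis=1)
    return {i: int(j) for i, j in enumerate(nearest) if i != j}


if __name__ == "__main__":
    prev = np.array([[10.0, 10.0, 0.0], [20.0, 20.0, 0.0]])
    cur = np.array([[20.5, 19.5, 0.0], [10.5, 9.5, 0.0]])  # the two tracks swapped
    print(detect_switches(prev, cur))  # {0: 1, 1: 0} -> a one-to-one switch

When the resulting mapping is one-to-one, handle_switches simply reorders all subsequent frames; otherwise it falls back to the greedy smallest-distance assignment shown in the patch. A typical invocation of the converter would then look something like `python src/conversion_scripts/convert_from_csv.py <trial>.csv --sep ";" -o <trial>.hdf5`, with the flags as defined by the argparse arguments added above.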