def _find_flip_pairs(turns, window=20, threshold=0.6 * np.pi):
    """Pair each large turn with the largest other large turn close to it.

    Turns are visited from the largest magnitude downwards. For every turn
    above ``threshold`` that is not yet part of a pair, the largest unused
    turn within ``window`` timesteps on either side is looked up; if that
    neighbour is above ``threshold`` as well, the two form a pair.

    Args:
        turns (array-like): 1-D sequence of turn values; NaN entries are
            treated as "no turn".
        window (int): Number of timesteps searched on each side of a turn.
        threshold (float): Minimum absolute turn for a timestep to count as
            a flip candidate.

    Returns:
        dict: Maps the index of each investigated turn to the index of its
        partner. Both are absolute indices into ``turns`` and every index
        takes part in at most one pair.
    """
    magnitudes = np.abs(np.asarray(turns, dtype=np.float64))
    # NaN turns can never qualify; zeroing them keeps every index aligned
    # with the original array (no filtering, hence no index shift).
    magnitudes[np.isnan(magnitudes)] = 0.0

    pairs = {}
    used = set()

    for candidate in np.argsort(magnitudes)[::-1]:
        candidate = int(candidate)
        if magnitudes[candidate] <= threshold:
            # Sorted in descending order: nothing that follows can qualify.
            break
        if candidate in used:
            continue

        # Clamp at 0 so a turn near the start of the track does not produce
        # a negative slice start (which would wrap around or come out empty).
        window_start = max(0, candidate - window)
        neighbourhood = magnitudes[window_start : candidate + window + 1].copy()
        neighbourhood[candidate - window_start] = 0.0  # never pair with itself

        for offset in np.argsort(neighbourhood)[::-1]:
            if neighbourhood[offset] <= threshold:
                break
            # Translate the window-relative offset to an absolute timestep
            # before comparing it with the indices that are already paired.
            partner = window_start + int(offset)
            if partner in used:
                continue
            pairs[candidate] = partner
            used.update((candidate, partner))
            break

    return pairs


def eliminate_flips(io_file_path, analysis_path) -> list:
    """Correct the orientations between pairs of close, large turns.

    For every entity in the file the turn column of ``actions_speeds_turns``
    is searched for pairs of turns larger than 0.6 * pi that lie within 20
    timesteps of each other. The orientations between the two turns of a
    pair are rewritten in place, the affected timesteps are stored in the
    file attribute ``"switches"`` and the calculated data is updated.

    Args:
        io_file_path (str): The path to the hdf5 file that should be corrected.
        analysis_path (str | Path | None): Where to save a scatter plot of the
            sorted absolute turns (paired turns in red). If it is an existing
            directory, ``flip_analysis.png`` is written into it. ``None``
            disables plotting.

    Returns:
        list: The first timestep of every corrected interval, in the order
        the intervals were processed.
    """
    if analysis_path is not None:
        analysis_path = Path(analysis_path)
        if analysis_path.is_dir():
            analysis_path = analysis_path / "flip_analysis.png"

    print("Eliminating flips in ", io_file_path)

    flipped_timesteps = set()
    flip_starts = []

    with robofish.io.File(io_file_path, "r+") as iof:
        n_timesteps = iof.entity_actions_speeds_turns.shape[1]
        entities = iof.entities

        # Only create a figure when it will be saved. squeeze=False keeps
        # `axes` two-dimensional, so indexing also works for a single entity.
        fig = None
        axes = None
        if analysis_path is not None and len(entities) > 0:
            fig, axes = plt.subplots(
                1, len(entities), figsize=(20, 7), squeeze=False
            )

        for e, entity in enumerate(entities):
            turns = np.asarray(entity.actions_speeds_turns[:, 1], dtype=np.float64)
            flips = _find_flip_pairs(turns)

            for first, second in flips.items():
                start = min(first, second)
                end = max(first, second)

                # NOTE(review): `pi - theta` mirrors an angle, whereas a
                # head/tail flip is usually `theta + pi`. Kept exactly as it
                # was; confirm against the layout of "orientations".
                entity["orientations"][start:end] = (
                    np.pi - entity["orientations"][start:end]
                ) % (np.pi * 2)

                print(f"Flipping from {start} to {end}")

                flipped_timesteps.update(range(start, end))
                flip_starts.append(start)

            if axes is not None:
                # Plot the absolute turns in ascending order and mark those
                # that belong to a detected pair.
                turn_order = np.argsort(np.abs(turns))
                flip_indices = list(flips.keys()) + list(flips.values())
                colors = np.where(
                    np.isin(turn_order, flip_indices), "red", "blue"
                ).tolist()

                axes[0, e].scatter(
                    np.arange(len(turns), dtype=np.int32),
                    np.abs(turns[turn_order]) / np.pi,
                    c=colors,
                    alpha=0.5,
                )

        if fig is not None:
            fig.savefig(analysis_path)
            # Close explicitly: this runs once per converted file and open
            # figures would otherwise pile up.
            plt.close(fig)

        # Unique, ascending, and limited to valid timesteps.
        # NOTE(review): the flipped timesteps are stored under the attribute
        # name "switches"; confirm this does not clash with identity switches.
        iof.attrs["switches"] = np.array(
            sorted(t for t in flipped_timesteps if t < n_timesteps),
            dtype=np.int32,
        )

        iof.update_calculated_data()

    return flip_starts
Folder will create two png files.", +) args = parser.parse_args() +if args.analysis_path is not None: + args.analysis_path = Path(args.analysis_path) + for path in args.path: path = Path(path) @@ -298,12 +414,21 @@ for path in args.path: continue if path.suffix == ".csv": - handle_file(path, args) + files = [path] elif path.is_dir(): files = path.rglob("*.csv") - for file in tqdm(files): - handle_file(file, args) else: print("'%s' is not a folder nor a csv file" % path) continue + + for file in tqdm(files): + io_file_path, all_switches = handle_file(file, args) + if not args.disable_fix_flips: + all_flipps = eliminate_flips(io_file_path, args.analysis_path) + + if args.analysis_path is not None and args.analysis_path.is_dir(): + plt.figure() + plt.hist([all_switches, all_flipps], bins=50, label=["switches", "flips"]) + plt.legend() + plt.savefig(args.analysis_path / "switches_flipps.png") diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 114e4176560f200f2f006f5149e1162e1cad3bed..2f7d146a320611ef6a37fc21f5ca03232de367fc 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -513,7 +513,7 @@ class File(h5py.File): ), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}" assert poses.shape[2] in [3, 4] agents = poses.shape[0] - entity_names = [] + entities = [] for i in range(agents): e_name = None if names is None else names[i] @@ -521,7 +521,7 @@ class File(h5py.File): outlines if outlines is None or outlines.ndim == 3 else outlines[i] ) individual_id = None if individual_ids is None else individual_ids[i] - entity_names.append( + entities.append( self.create_entity( category=category, sampling=sampling, @@ -531,7 +531,7 @@ class File(h5py.File): outlines=e_outline, ) ) - return entity_names + return entities def update_calculated_data(self, verbose=False): changed = any([e.update_calculated_data(verbose) for e in self.entities])