From 5c588e8c8a456f38c118f94499f13525c23ce3ef Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Thu, 15 Dec 2022 18:00:09 +0000 Subject: [PATCH] Improved switch detection --- src/conversion_scripts/convert_from_csv.py | 257 ++++++++++----------- src/conversion_scripts/create_test_csv.py | 33 +++ src/conversion_scripts/print_file.py | 41 ++++ src/robofish/io/app.py | 1 + src/robofish/io/file.py | 29 ++- 5 files changed, 224 insertions(+), 137 deletions(-) create mode 100644 src/conversion_scripts/create_test_csv.py create mode 100644 src/conversion_scripts/print_file.py diff --git a/src/conversion_scripts/convert_from_csv.py b/src/conversion_scripts/convert_from_csv.py index e437cd8..def488e 100644 --- a/src/conversion_scripts/convert_from_csv.py +++ b/src/conversion_scripts/convert_from_csv.py @@ -15,6 +15,7 @@ The script is not finished and still in progress. import pandas as pd import numpy as np import argparse +import itertools from pathlib import Path import robofish.io @@ -27,41 +28,7 @@ except ImportError as e: return x -DEFAULT_COLUMNS = [ - "Framenumber", - "TimeHuman", - "Time", - "Type1", - "ID", - "Robo x", - "Robo y", - "Robo ori deg", - "Robo ori rad", - "Type2", - "ID2", - "Fish x", - "Fish y", - "Fish ori deg", - "Fish ori rad", - "Mode", - "FishModel", - "LeadSubexperiment", - "fear", - "follow", - "cz", - "rz", - "lz", - "v_max", - "adapt", - "ADir", - "ASpeed", - "AClose", - "LSpeed", - "LClose", -] - - -def get_distances(poses, last_poses=None): +def get_distances(poses, last_poses=None, diagonal=False): def get_single_distance(pose_1, pose_2): assert pose_1.shape == pose_2.shape == (3,) return np.linalg.norm(pose_1[:2] - pose_2[:2]) @@ -71,10 +38,11 @@ def get_distances(poses, last_poses=None): distances = np.zeros((n_fish, n_fish)) for i in range(n_fish): for j in range(n_fish): - if last_poses is None: - distances[i, j] = get_single_distance(poses[i], poses[j]) - else: - distances[i, j] = get_single_distance(poses[i], last_poses[j]) + if i == j or not diagonal: + if last_poses is None: + distances[i, j] = get_single_distance(poses[i], poses[j]) + else: + distances[i, j] = get_single_distance(poses[i], last_poses[j]) else: n_frames, n_fish, three = poses.shape distances = np.zeros((n_frames, n_fish, n_fish)) @@ -82,11 +50,96 @@ def get_distances(poses, last_poses=None): for t in range(n_frames - 1): for i in range(n_fish): for j in range(n_fish): - assert last_poses is None - distances[t, i, j] = get_single_distance( - poses[t + 1, i], poses[t, j] - ) - return distances + if i == j or not diagonal: + assert last_poses is None + distances[t, i, j] = get_single_distance( + poses[t + 1, i], poses[t, j] + ) + if diagonal: + return np.diagonal(distances, axis1=-1, axis2=-2) + else: + return distances + + +def handle_switches(poses, supress=None): + """Find and handle switches in the data. + Switches are defined as a fish that is not at the same position as it was in the previous frame. + If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame. + If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved. + + Args: + poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish. + supress (list): List of frames to ignore. + Returns: + np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled. + """ + + n_fish = poses.shape[1] + + all_switches = [] + + last_poses = np.copy(poses[0]) + + all_connection_permutations = list(itertools.permutations(np.arange(n_fish))) + + for t in range(1, poses.shape[0]): + if np.all(np.isclose(np.abs(np.diff(poses[t - 1], axis=0)), 0)): + print( + f"Warning: All fish at same position in frame {t - 1} setting them to NaN" + ) + poses[t - 1] = np.nan + + if supress is None or t not in supress: + + distances = get_distances(poses[t], last_poses) + + switches = {} + for i in range(n_fish): + if np.argmin(distances[i]) != i: + switches[i] = np.argmin(distances[i]) + + if len(switches) > 0 or np.any( + np.isnan(poses[t]) != np.isnan(poses[t - 1]) + ): + + switch_distances = np.array( + [ + get_distances(poses[t, con_perm], last_poses, diagonal=True) + for con_perm in all_connection_permutations + ] + ) + + switch_distances_sum = np.nansum(switch_distances, axis=1) + + # if t > 10 and t < 18: + # print("Poses\n", poses[t], "\nLast Poses\n", last_poses) + # print(t, "\n", switch_distances[:2], "\n", switch_distances_sum[:2]) + + connections = np.array( + all_connection_permutations[np.argmin(switch_distances_sum)] + ).astype(int) + + distances_normal = switch_distances_sum[0] + distances_switched = np.min(switch_distances_sum) + + if np.argmin(switch_distances_sum) != 0: + if distances_switched * 1.5 < distances_normal: + + print( + f"Switch: {connections} distance sum\t{np.min(switch_distances_sum):.2f} vs\t{np.min(switch_distances_sum[0]):.2f}" + ) + poses[t:] = poses[t:, connections] + + all_switches.append(t) + + # Update last poses for every fish that is not nan + last_poses[np.where(~np.isnan(poses[t]))] = poses[t][ + np.where(~np.isnan(poses[t])) + ] + + assert not np.any(np.isnan(last_poses)), "Error: NaN in last_poses" + + return poses, all_switches def handle_file(file, args): @@ -114,7 +167,6 @@ def handle_file(file, args): all_col_types_matching.append(cols) # print(cols, "\t", matching_types) - # pf.columns = DEFAULT_COLUMNS if len(all_col_types_matching) == 0: print( "Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity" @@ -128,7 +180,7 @@ def handle_file(file, args): print("Found columns_per_entity: %d" % columns_per_entity) else: - columns_per_entity = args.columns_per_entity + columns_per_entity = int(args.columns_per_entity) n_fish = len(pf.columns) // columns_per_entity header_cols = len(pf.columns) % n_fish @@ -167,113 +219,48 @@ def handle_file(file, args): + f * columns_per_entity : header_cols + (f + 1) * columns_per_entity, ] - poses[:, f] = f_cols[:, [args.xcol, args.ycol, args.oricol]].astype( - np.float32 - ) - - poses[:, :, :2] -= args.world_size / 2 # center world around 0,0 - poses[:, :, 1] *= -1 # flip y axis - poses[:, :, 2] *= 2 * np.pi / 360 # convert to radians - poses[:, :, 2] -= np.pi # rotate 180 degrees - - def handle_switches(poses, supress=None): - """Find and handle switches in the data. - Switches are defined as a fish that is not at the same position as it was in the previous frame. - If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame. - If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved. - - Args: - poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish. - supress (list): List of frames to ignore. - Returns: - np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled. - """ - - all_switches = [] - last_poses = np.copy(poses[0]) - assert not np.any(np.isnan(last_poses)), "Error: NaN in first frame" - for t in range(1, poses.shape[0]): - if not np.any(np.isnan(poses[t])) and ( - supress is None or t not in supress - ): - - distances = get_distances(poses[t], last_poses) - - switches = {} - for i in range(n_fish): - if np.argmin(distances[i]) != i: - switches[i] = np.argmin(distances[i]) - - if len(switches) > 0: - print(f"Switches at time {t}: {switches}") - - if sorted(switches.keys()) == sorted(switches.values()): - print("Switches are one-to-one") - - for i in switches: - print(list(switches.keys()), list(switches.values())) - connections = np.arange(n_fish, dtype=int) - for k, v in switches.items(): - connections[v] = k - else: - print("Attempting to fix switches...") - - smallest_distances = np.argsort(distances.flatten()) - - connections = np.empty(n_fish) - connections[:] = np.nan - - print(distances) - print( - smallest_distances[0] // n_fish, - smallest_distances[0] % n_fish, - ) - for sd in smallest_distances: - if np.isnan(connections[sd // n_fish]) and not np.any( - connections == sd % n_fish - ): - connections[sd // n_fish] = sd % n_fish - - if not np.any(np.isnan(connections)): - break - assert np.sum(connections) == np.sum( - range(n_fish) - ) # Simple check to see if all fish are connected - - connections = connections.astype(int) - print(f"Connections: {connections}") - poses[t:] = poses[t:, connections] - - all_switches.append(t) + poses[:, f] = f_cols[ + :, [int(args.xcol), int(args.ycol), int(args.oricol)] + ].astype(np.float32) - last_poses = poses[t] - return poses, all_switches + if not args.disable_centering: + poses[:, :, :2] -= args.world_size / 2 # center world around 0,0 + poses[:, :, 1] *= -1 # flip y axis + poses[:, :, 2] *= 2 * np.pi / 360 # convert to radians + poses[:, :, 2] -= np.pi # rotate 180 degrees supress = [] + all_switches = None if not args.disable_fix_switches: for run in range(20): print("RUN ", run) switched_poses, all_switches = handle_switches(poses) - - print(all_switches) + print("All switches: ", all_switches) diff = np.diff(all_switches, axis=0) - new_supress = np.array(all_switches)[np.where(diff < 10)] + new_supress = np.array(all_switches)[ + np.where(diff < args.min_timesteps_between_switches) + ] new_supress = [n for n in new_supress if n not in supress] + if len(new_supress) == 0: + print("No need to supress more.") break - print(supress) + supress.extend(new_supress) + print("supressing in the next runn: ", supress) - distances = get_distances(switched_poses) - switched_poses[ - np.where(np.diagonal(distances > 1, axis1=1, axis2=2)) - ] = np.nan + # distances = get_distances(switched_poses) + # switched_poses[ + # np.where(np.diagonal(distances > 1, axis1=1, axis2=2)) + # ] = np.nan poses = switched_poses for f in range(n_fish): iof.create_entity("fish", poses[:, f]) + if all_switches is not None: + iof.attrs["switches"] = all_switches # assert np.all( # poses[np.logical_not(np.isnan(poses[:, 0])), 0] >= 0 @@ -299,6 +286,8 @@ parser.add_argument("--oricol", default=7) parser.add_argument("--world_size", default=100) parser.add_argument("--frequency", default=25) parser.add_argument("--disable_fix_switches", action="store_true") +parser.add_argument("--disable_centering", action="store_true") +parser.add_argument("--min_timesteps_between_switches", type=int, default=0) args = parser.parse_args() for path in args.path: diff --git a/src/conversion_scripts/create_test_csv.py b/src/conversion_scripts/create_test_csv.py new file mode 100644 index 0000000..40623ec --- /dev/null +++ b/src/conversion_scripts/create_test_csv.py @@ -0,0 +1,33 @@ +import pandas as pd +import numpy as np + +# python convert_from_csv.py test.csv --sep , --header 1 --columns_per_entity 3 --xcol 0 --ycol 1 --oricol 2 --disable_fix_switches + +df = pd.DataFrame(columns=["header", "x1", "y1", "o1", "x2", "y2", "o2"]) + +df["header"] = 0 + +df["x1"] = np.linspace(20, 80, 60) +df["x2"] = np.linspace(20, 80, 60) +df["y1"] = 40 +df["y2"] = 60 + +# Switch y1 and y2 after 50 samples +df.loc[20:30, "y1"] = 60 +df.loc[20:30, "y2"] = 40 + +# Set x1 and y1 to np.nan for 5 samples +df.loc[40:45, "x1"] = np.nan +df.loc[40:45, "y1"] = np.nan +df.loc[46:50, "y1"] = 60 +df.loc[46:50, "y2"] = 40 +df.loc[51:55, "x2"] = np.nan +df.loc[51:55, "y2"] = np.nan + + +df["o1"] = 180 +df["o2"] = 180 + + +print(df) +df.to_csv("test.csv", index=False) diff --git a/src/conversion_scripts/print_file.py b/src/conversion_scripts/print_file.py new file mode 100644 index 0000000..03455b1 --- /dev/null +++ b/src/conversion_scripts/print_file.py @@ -0,0 +1,41 @@ +import robofish.io +import pandas as pd +import numpy as np + + +def print_file(args): + + # Open the robofish.io.File + with robofish.io.File(args.filename, "r") as f: + timesteps = ( + args.max_timesteps + args.skip_timesteps + if args.max_timesteps is not None + else f.entity_positions.shape[1] + ) + + n_fish = len(f.entities) + df = pd.DataFrame( + columns=np.concatenate([[f"x{i}", f"y{i}", f"o{i}"] for i in range(n_fish)]) + ) + + for i in range(n_fish): + df[f"x{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 0] + df[f"y{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 1] + df[f"o{i}"] = f.entity_orientations_rad[i][ + args.skip_timesteps : timesteps, 0 + ] + + print(df) + + df.to_csv("cutout.csv", index=False) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Print a file") + parser.add_argument("filename", help="File to print") + parser.add_argument("--skip_timesteps", type=int, default=0, help="Skip timesteps") + parser.add_argument("--max_timesteps", type=int, default=None, help="Max timesteps") + args = parser.parse_args() + print_file(args) diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 16f714d..eaf9d96 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -235,6 +235,7 @@ def render(args: argparse.Namespace = None) -> None: "show_ids": False, "render_goals": False, "render_targets": False, + "highlight_switches": False, "figsize": 10, } diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index b3f7913..114e417 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -1035,6 +1035,7 @@ class File(h5py.File): "show_ids": False, "render_goals": False, "render_targets": False, + "highlight_switches": False, "dpi": 200, "figsize": 10, } @@ -1086,16 +1087,20 @@ class File(h5py.File): plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], ] categories = [entity.attrs.get("category", None) for entity in self.entities] + entity_colors = [lines[entity].get_color() for entity in range(n_entities)] entity_polygons = [ patches.Polygon( shape_vertices(options["entity_scale"]), edgecolor=edgecolor, facecolor=color, + alpha=0.8, ) for edgecolor, color in [ - ("k", "white") if category == "robot" else ("k", "k") - for category in categories + ("k", "white") + if category == "robot" + else (entity_colors[entity], entity_colors[entity]) + for entity, category in enumerate(categories) ] ] @@ -1197,6 +1202,8 @@ class File(h5py.File): pbar = tqdm(range(n_frames)) def update(frame): + output_list = [] + if "pbar" in locals().keys(): pbar.update(1) pbar.refresh() @@ -1207,6 +1214,19 @@ class File(h5py.File): file_frame = (frame * options["speedup"]) + frame_range[0] this_pose = entity_poses[:, file_frame] + if options["highlight_switches"] and "switches" in self.attrs: + if any( + [ + file_frame + i in self.attrs["switches"] + for i in range(options["speedup"]) + ] + ): + ax.set_facecolor("lightgray") + + else: + ax.set_facecolor("white") + output_list.append(ax) + if not options["fixed_view"]: # Find the maximal distance between the entities in x or y direction @@ -1284,7 +1304,10 @@ class File(h5py.File): raise Exception( f"Frame is bigger than n_frames {file_frame} of {n_frames}" ) - return lines + entity_polygons + [border] + points + annotations + + return ( + output_list + lines + entity_polygons + [border] + points + annotations + ) print(f"Preparing to render n_frames: {n_frames}") -- GitLab