Skip to content
Snippets Groups Projects
Commit 5c588e8c authored by Andi Gerken's avatar Andi Gerken
Browse files

Improved switch detection

parent db27f89b
No related branches found
No related tags found
1 merge request!43Improved switch detection
Pipeline #53832 passed
......@@ -15,6 +15,7 @@ The script is not finished and still in progress.
import pandas as pd
import numpy as np
import argparse
import itertools
from pathlib import Path
import robofish.io
......@@ -27,41 +28,7 @@ except ImportError as e:
return x
DEFAULT_COLUMNS = [
"Framenumber",
"TimeHuman",
"Time",
"Type1",
"ID",
"Robo x",
"Robo y",
"Robo ori deg",
"Robo ori rad",
"Type2",
"ID2",
"Fish x",
"Fish y",
"Fish ori deg",
"Fish ori rad",
"Mode",
"FishModel",
"LeadSubexperiment",
"fear",
"follow",
"cz",
"rz",
"lz",
"v_max",
"adapt",
"ADir",
"ASpeed",
"AClose",
"LSpeed",
"LClose",
]
def get_distances(poses, last_poses=None):
def get_distances(poses, last_poses=None, diagonal=False):
def get_single_distance(pose_1, pose_2):
assert pose_1.shape == pose_2.shape == (3,)
return np.linalg.norm(pose_1[:2] - pose_2[:2])
......@@ -71,6 +38,7 @@ def get_distances(poses, last_poses=None):
distances = np.zeros((n_fish, n_fish))
for i in range(n_fish):
for j in range(n_fish):
if i == j or not diagonal:
if last_poses is None:
distances[i, j] = get_single_distance(poses[i], poses[j])
else:
......@@ -82,13 +50,98 @@ def get_distances(poses, last_poses=None):
for t in range(n_frames - 1):
for i in range(n_fish):
for j in range(n_fish):
if i == j or not diagonal:
assert last_poses is None
distances[t, i, j] = get_single_distance(
poses[t + 1, i], poses[t, j]
)
if diagonal:
return np.diagonal(distances, axis1=-1, axis2=-2)
else:
return distances
def handle_switches(poses, supress=None):
"""Find and handle switches in the data.
Switches are defined as a fish that is not at the same position as it was in the previous frame.
If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame.
If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved.
Args:
poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
supress (list): List of frames to ignore.
Returns:
np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
"""
n_fish = poses.shape[1]
all_switches = []
last_poses = np.copy(poses[0])
all_connection_permutations = list(itertools.permutations(np.arange(n_fish)))
for t in range(1, poses.shape[0]):
if np.all(np.isclose(np.abs(np.diff(poses[t - 1], axis=0)), 0)):
print(
f"Warning: All fish at same position in frame {t - 1} setting them to NaN"
)
poses[t - 1] = np.nan
if supress is None or t not in supress:
distances = get_distances(poses[t], last_poses)
switches = {}
for i in range(n_fish):
if np.argmin(distances[i]) != i:
switches[i] = np.argmin(distances[i])
if len(switches) > 0 or np.any(
np.isnan(poses[t]) != np.isnan(poses[t - 1])
):
switch_distances = np.array(
[
get_distances(poses[t, con_perm], last_poses, diagonal=True)
for con_perm in all_connection_permutations
]
)
switch_distances_sum = np.nansum(switch_distances, axis=1)
# if t > 10 and t < 18:
# print("Poses\n", poses[t], "\nLast Poses\n", last_poses)
# print(t, "\n", switch_distances[:2], "\n", switch_distances_sum[:2])
connections = np.array(
all_connection_permutations[np.argmin(switch_distances_sum)]
).astype(int)
distances_normal = switch_distances_sum[0]
distances_switched = np.min(switch_distances_sum)
if np.argmin(switch_distances_sum) != 0:
if distances_switched * 1.5 < distances_normal:
print(
f"Switch: {connections} distance sum\t{np.min(switch_distances_sum):.2f} vs\t{np.min(switch_distances_sum[0]):.2f}"
)
poses[t:] = poses[t:, connections]
all_switches.append(t)
# Update last poses for every fish that is not nan
last_poses[np.where(~np.isnan(poses[t]))] = poses[t][
np.where(~np.isnan(poses[t]))
]
assert not np.any(np.isnan(last_poses)), "Error: NaN in last_poses"
return poses, all_switches
def handle_file(file, args):
pf = pd.read_csv(file, skiprows=4, header=None, sep=args.sep)
if args.columns_per_entity is None:
......@@ -114,7 +167,6 @@ def handle_file(file, args):
all_col_types_matching.append(cols)
# print(cols, "\t", matching_types)
# pf.columns = DEFAULT_COLUMNS
if len(all_col_types_matching) == 0:
print(
"Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity"
......@@ -128,7 +180,7 @@ def handle_file(file, args):
print("Found columns_per_entity: %d" % columns_per_entity)
else:
columns_per_entity = args.columns_per_entity
columns_per_entity = int(args.columns_per_entity)
n_fish = len(pf.columns) // columns_per_entity
header_cols = len(pf.columns) % n_fish
......@@ -167,113 +219,48 @@ def handle_file(file, args):
+ f * columns_per_entity : header_cols
+ (f + 1) * columns_per_entity,
]
poses[:, f] = f_cols[:, [args.xcol, args.ycol, args.oricol]].astype(
np.float32
)
poses[:, f] = f_cols[
:, [int(args.xcol), int(args.ycol), int(args.oricol)]
].astype(np.float32)
if not args.disable_centering:
poses[:, :, :2] -= args.world_size / 2 # center world around 0,0
poses[:, :, 1] *= -1 # flip y axis
poses[:, :, 2] *= 2 * np.pi / 360 # convert to radians
poses[:, :, 2] -= np.pi # rotate 180 degrees
def handle_switches(poses, supress=None):
"""Find and handle switches in the data.
Switches are defined as a fish that is not at the same position as it was in the previous frame.
If a switch is found, the fish is moved to the position of the fish that was at the same position in the previous frame.
If multiple switches are found, the fish that is closest to the position of the fish that was at the same position in the previous frame is moved.
Args:
poses (np.ndarray): Array of shape (n_frames, n_fish, 3) containing the poses of the fish.
supress (list): List of frames to ignore.
Returns:
np.ndarray: Array of shape (n_frames, n_fish, 3) containing the poses of the fish with switches handled.
"""
all_switches = []
last_poses = np.copy(poses[0])
assert not np.any(np.isnan(last_poses)), "Error: NaN in first frame"
for t in range(1, poses.shape[0]):
if not np.any(np.isnan(poses[t])) and (
supress is None or t not in supress
):
distances = get_distances(poses[t], last_poses)
switches = {}
for i in range(n_fish):
if np.argmin(distances[i]) != i:
switches[i] = np.argmin(distances[i])
if len(switches) > 0:
print(f"Switches at time {t}: {switches}")
if sorted(switches.keys()) == sorted(switches.values()):
print("Switches are one-to-one")
for i in switches:
print(list(switches.keys()), list(switches.values()))
connections = np.arange(n_fish, dtype=int)
for k, v in switches.items():
connections[v] = k
else:
print("Attempting to fix switches...")
smallest_distances = np.argsort(distances.flatten())
connections = np.empty(n_fish)
connections[:] = np.nan
print(distances)
print(
smallest_distances[0] // n_fish,
smallest_distances[0] % n_fish,
)
for sd in smallest_distances:
if np.isnan(connections[sd // n_fish]) and not np.any(
connections == sd % n_fish
):
connections[sd // n_fish] = sd % n_fish
if not np.any(np.isnan(connections)):
break
assert np.sum(connections) == np.sum(
range(n_fish)
) # Simple check to see if all fish are connected
connections = connections.astype(int)
print(f"Connections: {connections}")
poses[t:] = poses[t:, connections]
all_switches.append(t)
last_poses = poses[t]
return poses, all_switches
supress = []
all_switches = None
if not args.disable_fix_switches:
for run in range(20):
print("RUN ", run)
switched_poses, all_switches = handle_switches(poses)
print(all_switches)
print("All switches: ", all_switches)
diff = np.diff(all_switches, axis=0)
new_supress = np.array(all_switches)[np.where(diff < 10)]
new_supress = np.array(all_switches)[
np.where(diff < args.min_timesteps_between_switches)
]
new_supress = [n for n in new_supress if n not in supress]
if len(new_supress) == 0:
print("No need to supress more.")
break
print(supress)
supress.extend(new_supress)
print("supressing in the next runn: ", supress)
distances = get_distances(switched_poses)
switched_poses[
np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
] = np.nan
# distances = get_distances(switched_poses)
# switched_poses[
# np.where(np.diagonal(distances > 1, axis1=1, axis2=2))
# ] = np.nan
poses = switched_poses
for f in range(n_fish):
iof.create_entity("fish", poses[:, f])
if all_switches is not None:
iof.attrs["switches"] = all_switches
# assert np.all(
# poses[np.logical_not(np.isnan(poses[:, 0])), 0] >= 0
......@@ -299,6 +286,8 @@ parser.add_argument("--oricol", default=7)
parser.add_argument("--world_size", default=100)
parser.add_argument("--frequency", default=25)
parser.add_argument("--disable_fix_switches", action="store_true")
parser.add_argument("--disable_centering", action="store_true")
parser.add_argument("--min_timesteps_between_switches", type=int, default=0)
args = parser.parse_args()
for path in args.path:
......
import pandas as pd
import numpy as np
# python convert_from_csv.py test.csv --sep , --header 1 --columns_per_entity 3 --xcol 0 --ycol 1 --oricol 2 --disable_fix_switches
df = pd.DataFrame(columns=["header", "x1", "y1", "o1", "x2", "y2", "o2"])
df["header"] = 0
df["x1"] = np.linspace(20, 80, 60)
df["x2"] = np.linspace(20, 80, 60)
df["y1"] = 40
df["y2"] = 60
# Switch y1 and y2 after 50 samples
df.loc[20:30, "y1"] = 60
df.loc[20:30, "y2"] = 40
# Set x1 and y1 to np.nan for 5 samples
df.loc[40:45, "x1"] = np.nan
df.loc[40:45, "y1"] = np.nan
df.loc[46:50, "y1"] = 60
df.loc[46:50, "y2"] = 40
df.loc[51:55, "x2"] = np.nan
df.loc[51:55, "y2"] = np.nan
df["o1"] = 180
df["o2"] = 180
print(df)
df.to_csv("test.csv", index=False)
import robofish.io
import pandas as pd
import numpy as np
def print_file(args):
# Open the robofish.io.File
with robofish.io.File(args.filename, "r") as f:
timesteps = (
args.max_timesteps + args.skip_timesteps
if args.max_timesteps is not None
else f.entity_positions.shape[1]
)
n_fish = len(f.entities)
df = pd.DataFrame(
columns=np.concatenate([[f"x{i}", f"y{i}", f"o{i}"] for i in range(n_fish)])
)
for i in range(n_fish):
df[f"x{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 0]
df[f"y{i}"] = f.entity_positions[i][args.skip_timesteps : timesteps, 1]
df[f"o{i}"] = f.entity_orientations_rad[i][
args.skip_timesteps : timesteps, 0
]
print(df)
df.to_csv("cutout.csv", index=False)
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Print a file")
parser.add_argument("filename", help="File to print")
parser.add_argument("--skip_timesteps", type=int, default=0, help="Skip timesteps")
parser.add_argument("--max_timesteps", type=int, default=None, help="Max timesteps")
args = parser.parse_args()
print_file(args)
......@@ -235,6 +235,7 @@ def render(args: argparse.Namespace = None) -> None:
"show_ids": False,
"render_goals": False,
"render_targets": False,
"highlight_switches": False,
"figsize": 10,
}
......
......@@ -1035,6 +1035,7 @@ class File(h5py.File):
"show_ids": False,
"render_goals": False,
"render_targets": False,
"highlight_switches": False,
"dpi": 200,
"figsize": 10,
}
......@@ -1086,16 +1087,20 @@ class File(h5py.File):
plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
]
categories = [entity.attrs.get("category", None) for entity in self.entities]
entity_colors = [lines[entity].get_color() for entity in range(n_entities)]
entity_polygons = [
patches.Polygon(
shape_vertices(options["entity_scale"]),
edgecolor=edgecolor,
facecolor=color,
alpha=0.8,
)
for edgecolor, color in [
("k", "white") if category == "robot" else ("k", "k")
for category in categories
("k", "white")
if category == "robot"
else (entity_colors[entity], entity_colors[entity])
for entity, category in enumerate(categories)
]
]
......@@ -1197,6 +1202,8 @@ class File(h5py.File):
pbar = tqdm(range(n_frames))
def update(frame):
output_list = []
if "pbar" in locals().keys():
pbar.update(1)
pbar.refresh()
......@@ -1207,6 +1214,19 @@ class File(h5py.File):
file_frame = (frame * options["speedup"]) + frame_range[0]
this_pose = entity_poses[:, file_frame]
if options["highlight_switches"] and "switches" in self.attrs:
if any(
[
file_frame + i in self.attrs["switches"]
for i in range(options["speedup"])
]
):
ax.set_facecolor("lightgray")
else:
ax.set_facecolor("white")
output_list.append(ax)
if not options["fixed_view"]:
# Find the maximal distance between the entities in x or y direction
......@@ -1284,7 +1304,10 @@ class File(h5py.File):
raise Exception(
f"Frame is bigger than n_frames {file_frame} of {n_frames}"
)
return lines + entity_polygons + [border] + points + annotations
return (
output_list + lines + entity_polygons + [border] + points + annotations
)
print(f"Preparing to render n_frames: {n_frames}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment