Skip to content
Snippets Groups Projects
Commit 688bfc48 authored by Andi Gerken's avatar Andi Gerken
Browse files

Added quiver plot in evaluation.

Fixed Bugs with unfinished calculations in files.
parent c77fa515
No related branches found
No related tags found
3 merge requests!37Added calculation of individual ids,!34Added script to update individual ids, added docstrings massively,!32Added quiver plot in evaluation.
Pipeline #50064 failed
repos: repos:
- repo: https://github.com/psf/black - repo: https://github.com/psf/black
rev: 21.6b0 rev: 22.3.0
hooks: hooks:
- id: black - id: black
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/pre-commit/pre-commit-hooks
......
...@@ -68,6 +68,7 @@ setup( ...@@ -68,6 +68,7 @@ setup(
"Programming Language :: Python :: 3 :: Only", "Programming Language :: Python :: 3 :: Only",
"Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.7",
"Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.8",
"Programming Language :: Python :: 3.9",
], ],
python_requires=">=3.6", python_requires=">=3.6",
packages=[f"robofish.{p}" for p in find_packages("src/robofish")], packages=[f"robofish.{p}" for p in find_packages("src/robofish")],
......
...@@ -31,6 +31,7 @@ def function_dict(): ...@@ -31,6 +31,7 @@ def function_dict():
"follow_iid": base.evaluate_follow_iid, "follow_iid": base.evaluate_follow_iid,
"individual_speed": base.evaluate_individual_speed, "individual_speed": base.evaluate_individual_speed,
"individual_iid": base.evaluate_individual_iid, "individual_iid": base.evaluate_individual_iid,
"quiver": base.evaluate_quiver,
"all": base.evaluate_all, "all": base.evaluate_all,
} }
......
...@@ -22,6 +22,7 @@ from scipy import stats ...@@ -22,6 +22,7 @@ from scipy import stats
from tqdm import tqdm from tqdm import tqdm
import inspect import inspect
import random import random
import warnings
def evaluate_speed( def evaluate_speed(
...@@ -420,6 +421,145 @@ def evaluate_tank_position( ...@@ -420,6 +421,145 @@ def evaluate_tank_position(
return fig return fig
def evaluate_quiver(
paths: Iterable[Union[str, Path]],
labels: Iterable[str] = None,
predicate=None,
poses_from_paths=None,
speeds_turns_from_paths=None,
file_settings=None,
max_files=None,
bins=25,
):
"""Plot the flow of movement in the files."""
import torch
if poses_from_paths is None:
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
if speeds_turns_from_paths is None:
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns", max_files=max_files
)
# print(poses_from_paths.shape)
if len(poses_from_paths) != 1:
warnings.warn(
"NotImplemented: Only one path is supported for now in quiver plot."
)
poses_from_paths = poses_from_paths[:1]
poses_from_paths = np.array(poses_from_paths)
speeds_turns_from_paths = np.array(speeds_turns_from_paths)
all_poses = torch.tensor(poses_from_paths)[0, :, :, :-1].reshape((-1, 4))
speed = torch.tensor(speeds_turns_from_paths)[0, ..., 0].flatten()
all_poses_speed = torch.clone(all_poses)
all_poses_speed[:, 2] *= speed
all_poses_speed[:, 3] *= speed
tank_b = torch.linspace(
torch.min(all_poses[:, :2]), torch.max(all_poses[:, :2]), bins
)
nb = len(tank_b)
poses_buckets = torch.bucketize(all_poses[:, :2], tank_b)
tank_directions = np.zeros((nb, nb, 2))
tank_directions_speed = np.zeros_like(tank_directions)
tank_count = np.zeros((nb, nb))
for x in tqdm(range(nb)):
for y in range(nb):
d = torch.where(
torch.logical_and(poses_buckets[:, 0] == x, poses_buckets[:, 1] == y)
)[0]
if len(d) > 0:
# print("d", d)
# d_v = torch.stack((torch.cos(all_poses[d,2]), torch.sin(all_poses[d,2])), dim=1)
# print("dv",d_v.shape)
# tank_directions[x, y] = torch.mean(all_poses[d, 2:], dim=0)
tank_directions_speed[x, y] = torch.mean(all_poses_speed[d, 2:], dim=0)
# print(tank_directions[x,y] - tank_directions_speed[x,y])
tank_count[x, y] = len(d)
from fish_models.models.pascals_lstms.attribution import SocialVectors
sv = SocialVectors(torch.tensor(poses_from_paths[0]))
sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape(
(-1, 3)
) # [:1000]
sv_r = torch.cat(
(sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1
)
sv_r_s = torch.clone(sv_r)
sv_r_s[:, 2] *= speed
sv_r_s[:, 3] *= speed
social_b = torch.linspace(-20, 20, bins)
nb = len(social_b)
poses_buckets = torch.bucketize(sv_r[:, :2], social_b)
social_directions = np.zeros((nb, nb, 2))
social_directions_speed = np.zeros_like(social_directions)
social_count = np.zeros((nb, nb))
for x in tqdm(range(nb)):
for y in range(nb):
d = torch.where(
torch.logical_and(poses_buckets[:, 0] == x, poses_buckets[:, 1] == y)
)[0]
if len(d) > 0:
# print("d", d)
# print("dv",d_v.shape)
social_directions[x, y] = torch.mean(sv_r[d, 2:], dim=0)
social_directions_speed[x, y] = torch.mean(sv_r_s[d, 2:], dim=0)
social_count[x, y] = len(d)
tank_xs, tank_ys = np.meshgrid(tank_b, tank_b)
social_xs, social_ys = np.meshgrid(social_b, social_b)
fig, axs = plt.subplots(1, 2, figsize=(20, 10))
axs[0].quiver(
tank_xs,
tank_ys,
tank_directions_speed[..., 0],
-tank_directions_speed[..., 1],
alpha=tank_count / (np.max(tank_count)),
)
axs[0].set_title("Avg Direction weighted by speed")
axs[1].quiver(
social_xs,
social_ys,
social_directions_speed[..., 0],
-social_directions_speed[..., 1],
alpha=0.3 + social_count * 0.7 / (np.max(social_count)),
)
axs[1].set_title("Avg Direction weighted by speed in social situations")
fish_size = 50 / 20
linewidth = 2
head = axs[1].add_patch(
plt.Circle(
xy=(0, 0),
radius=fish_size,
linewidth=linewidth,
# color="white",
edgecolor="black",
facecolor="white",
fill=True,
alpha=0.8,
)
)
(tail,) = axs[1].plot(
[-fish_size, -fish_size * 2],
[0, 0],
color="black",
linewidth=linewidth,
)
return fig
def evaluate_social_vector( def evaluate_social_vector(
paths: Iterable[Union[str, Path]], paths: Iterable[Union[str, Path]],
labels: Iterable[str] = None, labels: Iterable[str] = None,
......
...@@ -16,6 +16,7 @@ from robofish.io import utils ...@@ -16,6 +16,7 @@ from robofish.io import utils
import argparse import argparse
import logging import logging
import warnings import warnings
from tqdm.auto import tqdm
def print_file(args=None): def print_file(args=None):
...@@ -75,11 +76,15 @@ def update_calculated_data(args=None): ...@@ -75,11 +76,15 @@ def update_calculated_data(args=None):
assert len(files) > 0, f"No files found in path {args.path}." assert len(files) > 0, f"No files found in path {args.path}."
for fp in files: pbar = tqdm(files)
print(f"File {fp}") for fp in pbar:
try: try:
with robofish.io.File(fp, "r+", validate_poses_hash=False) as f: with robofish.io.File(fp, "r+", validate_poses_hash=False) as f:
f.update_calculated_data(verbose=True) if f.update_calculated_data(verbose=False):
pbar.set_description(f"File {fp} was updated")
else:
pbar.set_description(f"File {fp} was already up to date")
except Exception as e: except Exception as e:
warnings.warn(f"The file {fp} could not be updated.") warnings.warn(f"The file {fp} could not be updated.")
print(e) print(e)
......
...@@ -248,6 +248,7 @@ class Entity(h5py.Group): ...@@ -248,6 +248,7 @@ class Entity(h5py.Group):
return np.stack([speed, turn], axis=-1) return np.stack([speed, turn], axis=-1)
def update_calculated_data(self, verbose=False, force_update=False): def update_calculated_data(self, verbose=False, force_update=False):
changed = False
if ( if (
"poses_hash" not in self.attrs "poses_hash" not in self.attrs
or self.attrs["poses_hash"] != self.poses_hash or self.attrs["poses_hash"] != self.poses_hash
...@@ -258,8 +259,8 @@ class Entity(h5py.Group): ...@@ -258,8 +259,8 @@ class Entity(h5py.Group):
): ):
try: try:
self.attrs["poses_hash"] = self.poses_hash self.attrs["poses_hash"] = self.poses_hash
self.attrs["unfinished_calculations"] = True
if "orientations" in self: if "orientations" in self:
self.attrs["unfinished_calculations"] = True
ori_rad = self.calculate_orientations_rad() ori_rad = self.calculate_orientations_rad()
if "calculated_orientations_rad" in self: if "calculated_orientations_rad" in self:
del self["calculated_orientations_rad"] del self["calculated_orientations_rad"]
...@@ -274,6 +275,7 @@ class Entity(h5py.Group): ...@@ -274,6 +275,7 @@ class Entity(h5py.Group):
del self.attrs["unfinished_calculations"] del self.attrs["unfinished_calculations"]
if verbose: if verbose:
changed = True
print( print(
f"Updated calculated data for entity {self.name} with poses_hash {self.poses_hash}" f"Updated calculated data for entity {self.name} with poses_hash {self.poses_hash}"
) )
...@@ -291,6 +293,7 @@ class Entity(h5py.Group): ...@@ -291,6 +293,7 @@ class Entity(h5py.Group):
) )
assert self.attrs["poses_hash"] == self.poses_hash assert self.attrs["poses_hash"] == self.poses_hash
return changed
def calculate_orientations_rad(self): def calculate_orientations_rad(self):
ori_rad = utils.limit_angle_range( ori_rad = utils.limit_angle_range(
......
...@@ -516,8 +516,8 @@ class File(h5py.File): ...@@ -516,8 +516,8 @@ class File(h5py.File):
return entity_names return entity_names
def update_calculated_data(self, verbose=False): def update_calculated_data(self, verbose=False):
for e in self.entities: changed = any([e.update_calculated_data(verbose) for e in self.entities])
e.update_calculated_data(verbose) return changed
def clear_calculated_data(self, verbose=True): def clear_calculated_data(self, verbose=True):
"""Delete all calculated data from the files.""" """Delete all calculated data from the files."""
...@@ -829,6 +829,8 @@ class File(h5py.File): ...@@ -829,6 +829,8 @@ class File(h5py.File):
self, self,
ax=None, ax=None,
lw_distances=False, lw_distances=False,
lw=2,
ms=32,
figsize=None, figsize=None,
step_size=4, step_size=4,
c=None, c=None,
...@@ -876,7 +878,6 @@ class File(h5py.File): ...@@ -876,7 +878,6 @@ class File(h5py.File):
) )
else: else:
step_size = poses.shape[1] step_size = poses.shape[1]
line_width = 1
cmap = cm.get_cmap(cmap) cmap = cm.get_cmap(cmap)
...@@ -907,8 +908,7 @@ class File(h5py.File): ...@@ -907,8 +908,7 @@ class File(h5py.File):
for t in range(skip_timesteps, timesteps, step_size): for t in range(skip_timesteps, timesteps, step_size):
if lw_distances: if lw_distances:
lw = np.mean(line_width[t : t + step_size + 1]) lw = np.mean(line_width[t : t + step_size + 1])
else:
lw = 1
ax.plot( ax.plot(
poses[fish_id, t : t + step_size + 1, 0], poses[fish_id, t : t + step_size + 1, 0],
poses[fish_id, t : t + step_size + 1, 1], poses[fish_id, t : t + step_size + 1, 1],
...@@ -917,26 +917,29 @@ class File(h5py.File): ...@@ -917,26 +917,29 @@ class File(h5py.File):
) )
# Plotting outside of the figure to have the label # Plotting outside of the figure to have the label
ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id) ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)
ax.scatter(
[poses[:, skip_timesteps, 0]], # ax.scatter(
[poses[:, skip_timesteps, 1]], # [poses[:, skip_timesteps, 0]],
marker="h", # [poses[:, skip_timesteps, 1]],
c="black", # marker="h",
s=32, # c="black",
label="Start", # s=ms,
zorder=5, # label="Start",
) # zorder=5,
# )
ax.scatter( ax.scatter(
[poses[:, max_timesteps, 0]], [poses[:, max_timesteps, 0]],
[poses[:, max_timesteps, 1]], [poses[:, max_timesteps, 1]],
marker="x", marker="x",
c="black", c="black",
s=32, s=ms,
label="End", label="End",
zorder=5, zorder=5,
) )
if legend: if legend and isinstance(legend, str):
ax.legend(loc="lower right") ax.legend(legend)
elif legend:
ax.legend()
ax.set_xlabel("x [cm]") ax.set_xlabel("x [cm]")
ax.set_ylabel("y [cm]") ax.set_ylabel("y [cm]")
...@@ -1088,6 +1091,7 @@ class File(h5py.File): ...@@ -1088,6 +1091,7 @@ class File(h5py.File):
return lines + entity_polygons + [border] + points return lines + entity_polygons + [border] + points
n_frames = self.entity_poses.shape[1] n_frames = self.entity_poses.shape[1]
if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None: if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
options["cut_frames_end"] = n_frames options["cut_frames_end"] = n_frames
frame_range = ( frame_range = (
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment