Skip to content
Snippets Groups Projects
Commit 33228dc5 authored by mhocke's avatar mhocke
Browse files

Render swarm center

parent 07869025
No related branches found
No related tags found
1 merge request!50Merge dev_mathis with multiple changes
......@@ -236,6 +236,8 @@ def render(args: argparse.Namespace = None) -> None:
"show_ids": False,
"render_goals": False,
"render_targets": False,
"render_zones": False,
"render_swarm_center": False,
"highlight_switches": False,
"figsize": 10,
}
......
......@@ -1060,6 +1060,8 @@ class File(h5py.File):
"show_ids": False,
"render_goals": False,
"render_targets": False,
"render_zones": False,
"render_swarm_center": False,
"highlight_switches": False,
"dpi": 200,
"figsize": 10,
......@@ -1081,6 +1083,26 @@ class File(h5py.File):
plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
for _ in range(n_entities)
]
zone_sizes = (
get_zone_sizes(self.attrs.get("guppy_model_rollout", ""))
if options["render_zones"]
else {}
)
zones = [
plt.Circle((0, 0), zone_size, color="k", alpha=0.5, fill=False)
for zone_size in zone_sizes.values()
]
if options["render_swarm_center"] and len([e for e in self.entities if e.category == "organism"]) > 0:
swarm_center_position = np.stack(
[e.positions for e in self.entities if e.category == "organism"]
).mean(axis=0)
swarm_center = [
plt.scatter([], [], marker=".", color="k"),
]
else:
swarm_center = []
for zone in zones:
ax.add_artist(zone)
annotations = []
if options["show_ids"]:
ids = None
......@@ -1233,7 +1255,7 @@ class File(h5py.File):
for e_poly in entity_polygons:
ax.add_patch(e_poly)
ax.add_patch(border)
return lines + entity_polygons + [border] + points
return lines + entity_polygons + [border] + points + zones + swarm_center
n_frames = self.entity_poses.shape[1]
......@@ -1348,6 +1370,15 @@ class File(h5py.File):
poses_trails = entity_poses[
:, max(0, file_frame - options["trail"]) : file_frame
]
for zone in zones:
zone.center = (
this_pose[0, 0],
this_pose[0, 1],
)
if options["render_swarm_center"]:
swarm_center[0].set_offsets(swarm_center_position[file_frame])
for i_entity in range(n_entities):
lines[i_entity].set_data(
poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
......@@ -1366,7 +1397,14 @@ class File(h5py.File):
)
return (
output_list + lines + entity_polygons + [border] + points + annotations
output_list
+ lines
+ entity_polygons
+ [border]
+ points
+ annotations
+ zones
+ swarm_center
)
print(f"Preparing to render n_frames: {n_frames}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment