Skip to content
Snippets Groups Projects
Commit 3c053452 authored by mhocke's avatar mhocke
Browse files

Refactor render cli

- use dashes in flag names like this: --some-flag (as used by python, pip,
  git...), while keeping old version --some_flag for backwards compatibility
- rename --cut_frames_start and --cut_frames_end to --start-frame and
  --end-frame (old flag still works)
- remove `options` and `default_options` from `File.render` (Options and
  default values needed to be added in three different places.)
parent d5ac5316
No related branches found
No related tags found
1 merge request!50Merge dev_mathis with multiple changes
Pipeline #66404 failed
......@@ -179,7 +179,7 @@ def render_file(kwargs: Dict) -> None:
kwargs (Dict, optional): A dictionary containing the arguments for the render function.
"""
with robofish.io.File(path=kwargs["path"]) as f:
f.render(**kwargs)
f.render(**{k:v for k,v in kwargs.items() if k not in ("path", "reference_track")})
def overwrite_user_configs() -> None:
......@@ -205,24 +205,42 @@ def render(args: argparse.Namespace = None) -> None:
"path",
type=str,
nargs="+",
help="The path to one file.",
help="The path to one or more files. "
"Multiple files get rendered in parallel in multiple windows.",
)
parser.add_argument(
"--video-path",
"-vp",
"--video_path",
"--video_path", # backwards compatibility
default=None,
type=str,
help="Path to save the video to (mp4). If a path is given, the animation won't be played.",
)
parser.add_argument(
"--reference_track",
"--reference-track",
"--reference_track", # backwards compatibility
action="store_true",
help="If true, the reference track will be rendered in parallel.",
default=False,
)
parser.add_argument(
"--start-frame",
"--cut_frames_start", # backwards compatibility
type=int,
default=0,
help="Skip all frames before this frame",
)
parser.add_argument(
"--end-frame",
"--cut_frames_end", # backwards compatibility
type=int,
default=0,
help="Last frame to render. Render until the end if set to 0.",
)
default_options = {
"linewidth": 2,
"speedup": 1,
......@@ -231,8 +249,6 @@ def render(args: argparse.Namespace = None) -> None:
"fixed_view": False,
"view_size": 60,
"slow_view": 0.8,
"cut_frames_start": 0,
"cut_frames_end": 0,
"show_text": False,
"show_ids": False,
"render_goals": False,
......@@ -246,17 +262,17 @@ def render(args: argparse.Namespace = None) -> None:
for key, value in default_options.items():
if isinstance(value, bool):
parser.add_argument(
f"--{key}",
*sorted({f"--{key.replace('_','-')}", f"--{key}"}),
default=value,
action="store_true" if value is False else "store_false",
help=f"Optional setter for video option {key}.",
help=f"{key.replace('_', ' ').capitalize()}",
)
else:
parser.add_argument(
f"--{key}",
*sorted({f"--{key.replace('_','-')}", f"--{key}"}),
default=value,
type=type(value),
help=f"Optional video option {key} with type {type(value)}.",
help=f"{key.replace('_', ' ').capitalize()}: {type(value).__name__}",
)
if args is None:
......
......@@ -983,33 +983,37 @@ class File(h5py.File):
return ax
def render(self, video_path: Optional[Union[str, Path]] = None, **kwargs: Dict) -> None:
def render(
self,
video_path: Optional[Union[str, Path]] = None,
linewidth: int = 2,
speedup: int = 1,
trail: int = 100,
entity_scale: float = 0.2,
fixed_view: bool = False,
view_size: int = 50,
margin: int = 15,
slow_view: float = 0.8,
slow_zoom: float = 0.95,
start_frame: bool = None,
end_frame: bool = None,
show_text: bool = False,
show_ids: bool = False,
render_goals: bool = False,
render_targets: bool = False,
render_zones: bool = False,
render_swarm_center: bool = False,
highlight_switches: bool = False,
dpi: int = 200,
figsize: int = 10,
) -> None:
"""Render a video of the file.
The tracks are rendered in a video using matplotlib.animation.FuncAnimation.
Additional options can be given as keyword arguments and overwrite the following default values:
- linewidth: 2,
- speedup: 1,
- trail: 100,
- entity_scale: 0.2,
- fixed_view: False,
- view_size: 50,
- margin: 15,
- slow_view: 0.8,
- slow_zoom: 0.95,
- cut_frames_start: None,
- cut_frames_end: None,
- show_text: False,
- show_ids: False,
- render_goals: False,
- render_targets: False,
- dpi: 200,
- figsize: 10,
Args:
video_path (Union[str, Path], optional): Path to save the video to. If None is given, the video is not saved.
kwargs (Dict): Additional arguments passed to the plot function.
...
Raises:
Exception: If ffmpeg is not installed and video_path is not None.
"""
......@@ -1046,37 +1050,7 @@ class File(h5py.File):
return base_shape * scale
default_options = {
"linewidth": 2,
"speedup": 1,
"trail": 100,
"entity_scale": 0.2,
"fixed_view": False,
"view_size": 50,
"margin": 15,
"slow_view": 0.8,
"slow_zoom": 0.95,
"cut_frames_start": None,
"cut_frames_end": None,
"show_text": False,
"show_ids": False,
"render_goals": False,
"render_targets": False,
"render_zones": False,
"render_swarm_center": False,
"highlight_switches": False,
"dpi": 200,
"figsize": 10,
}
options = {
key: kwargs[key] if key in kwargs else default_options[key]
for key in default_options.keys()
}
fig, ax = plt.subplots(
figsize=(options["figsize"], options["figsize"]), num=Path(self.path).name
)
fig, ax = plt.subplots(figsize=(figsize, figsize), num=Path(self.path).name)
ax.set_aspect("equal")
ax.set_facecolor("gray")
plt.tight_layout(pad=0.05)
......@@ -1084,10 +1058,7 @@ class File(h5py.File):
categories = [entity.attrs.get("category", None) for entity in self.entities]
n_fish = len([c for c in categories if c == "organism"])
lines = [
plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
for _ in range(n_entities)
]
lines = [plt.plot([], [], lw=linewidth, zorder=0)[0] for _ in range(n_entities)]
entity_colors = [lines[entity].get_color() for entity in range(n_entities)]
fish_colors = [
color
......@@ -1098,7 +1069,7 @@ class File(h5py.File):
zone_sizes = [
(
get_zone_sizes(self.attrs.get("guppy_model_rollout", ""))
if options["render_zones"]
if render_zones
else {}
)
for _ in range(n_fish)
......@@ -1117,7 +1088,7 @@ class File(h5py.File):
for zone in zones_fish:
ax.add_artist(zone)
zones_flat.append(zone)
if options["render_swarm_center"] and n_fish > 0:
if render_swarm_center and n_fish > 0:
swarm_center_position = np.stack(
[e.positions for e in self.entities if e.category == "organism"]
).mean(axis=0)
......@@ -1127,7 +1098,7 @@ class File(h5py.File):
else:
swarm_center = []
annotations = []
if options["show_ids"]:
if show_ids:
ids = None
if all(["individual_id" in e.attrs for e in self.entities]):
print("Getting IDs from individuals")
......@@ -1150,7 +1121,7 @@ class File(h5py.File):
else:
# Could not read individual ids
warnings.warn("Could not read individual ids.")
options["show_ids"] = False
show_ids = False
points = [
plt.scatter([], [], marker="x", color="k"),
......@@ -1160,7 +1131,7 @@ class File(h5py.File):
entity_polygons = [
patches.Polygon(
shape_vertices(
options["entity_scale"],
entity_scale,
has_orientation=entity.orientations_rad[:].std() > 0,
),
edgecolor=edgecolor,
......@@ -1284,16 +1255,16 @@ class File(h5py.File):
n_frames = self.entity_poses.shape[1]
if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
options["cut_frames_end"] = n_frames
if options["cut_frames_start"] is None:
options["cut_frames_start"] = 0
if end_frame == 0 or end_frame is None:
end_frame = n_frames
if start_frame is None:
start_frame = 0
frame_range = (
options["cut_frames_start"],
min(n_frames, options["cut_frames_end"]),
start_frame,
min(n_frames, end_frame),
)
n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])
n_frames = int((frame_range[1] - frame_range[0]) / speedup)
for skip in range(20):
start_pose = self.entity_poses_rad[:, frame_range[0] + skip]
......@@ -1306,7 +1277,7 @@ class File(h5py.File):
self.middle_of_swarm = np.mean(start_pose, axis=0)
min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
self.view_size = np.max([options["view_size"], min_view + options["margin"]])
self.view_size = np.max([view_size, min_view + margin])
pbar = tqdm(range(n_frames)) if video_path is not None else None
......@@ -1320,14 +1291,14 @@ class File(h5py.File):
if frame < n_frames:
entity_poses = self.entity_poses_rad
file_frame = (frame * options["speedup"]) + frame_range[0]
file_frame = (frame * speedup) + frame_range[0]
this_pose = entity_poses[:, file_frame]
if options["highlight_switches"] and "switches" in self.attrs:
if highlight_switches and "switches" in self.attrs:
if any(
[
file_frame + i in self.attrs["switches"]
for i in range(options["speedup"])
for i in range(speedup)
]
):
ax.set_facecolor("lightgray")
......@@ -1336,7 +1307,7 @@ class File(h5py.File):
ax.set_facecolor("white")
output_list.append(ax)
if not options["fixed_view"]:
if not fixed_view:
# Find the maximal distance between the entities in x or y direction
min_view = np.nanmax(
(np.nanmax(this_pose, axis=0) - np.nanmin(this_pose, axis=0))[
......@@ -1344,26 +1315,19 @@ class File(h5py.File):
]
)
new_view_size = np.nanmax(
[options["view_size"], min_view + options["margin"]]
)
new_view_size = np.nanmax([view_size, min_view + margin])
if (
not np.any(np.isnan(min_view))
and not np.any(np.isnan(new_view_size))
and not np.any(np.isnan(this_pose))
):
self.middle_of_swarm = options[
"slow_view"
] * self.middle_of_swarm + (
1 - options["slow_view"]
) * np.nanmean(
this_pose, axis=0
)
self.middle_of_swarm = slow_view * self.middle_of_swarm + (
1 - slow_view
) * np.nanmean(this_pose, axis=0)
self.view_size = (
options["slow_zoom"] * self.view_size
+ (1 - options["slow_zoom"]) * new_view_size
slow_zoom * self.view_size + (1 - slow_zoom) * new_view_size
)
ax.set_xlim(
......@@ -1375,25 +1339,23 @@ class File(h5py.File):
self.middle_of_swarm[1] + self.view_size / 2,
)
if options["show_text"]:
if show_text:
ax.set_title(title(file_frame))
if options["show_ids"]:
if show_ids:
for i in range(n_entities):
annotations[i].set_position(
(this_pose[i, 0] + 1, this_pose[i, 1])
)
annotations[i].xy = this_pose[i, :2]
if options["render_goals"]:
if render_goals:
update_goal(file_frame)
if options["render_targets"]:
if render_targets:
points[0].set_data(get_target(file_frame))
poses_trails = entity_poses[
:, max(0, file_frame - options["trail"]) : file_frame
]
poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame]
for i_entity in range(n_entities):
if categories[i_entity] == "organism":
for zone in zones[i_entity]:
......@@ -1402,7 +1364,7 @@ class File(h5py.File):
this_pose[i_entity, 1],
)
if options["render_swarm_center"]:
if render_swarm_center:
swarm_center[0].set_offsets(swarm_center_position[file_frame])
for i_entity in range(n_entities):
......@@ -1455,7 +1417,7 @@ class File(h5py.File):
if video_path is not None:
writervideo = animation.FFMpegWriter(fps=self.frequency)
ani.save(video_path, writer=writervideo, dpi=options["dpi"])
ani.save(video_path, writer=writervideo, dpi=dpi)
plt.close()
else:
plt.show()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment