diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 0b1126130e5a3f9228fa6f1d88e1c937ccbe20a5..e70562afa0672c511328afc30eaa5b45d1710521 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -27,6 +27,7 @@ from pathlib import Path import shutil import datetime import tempfile +import socket import uuid import deprecation import types @@ -853,12 +854,19 @@ class File(h5py.File): for key in default_options.keys() } - fig, ax = plt.subplots(figsize=(10, 10)) + if options["show_text"]: + fig, axes = plt.subplots(ncols=2, figsize=(20, 10)) + ax = axes[0] + ax.set(adjustable="box", aspect="equal") + ax_text = axes[1] + ax_text.axis("off") + else: + fig, ax = plt.subplots(figsize=(10, 10)) ax.set_facecolor("gray") n_entities = len(self.entities) lines = [ - plt.plot([], [], lw=options["linewidth"], zorder=0)[0] + ax.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] categories = [entity.attrs.get("category", None) for entity in self.entities] @@ -885,15 +893,22 @@ class File(h5py.File): ) xv, yv = np.meshgrid(x, y) - grid_points = plt.scatter(xv, yv, c="gray", s=1.5) + grid_points = ax.scatter(xv, yv, c="gray", s=1.5) - goal_point = plt.scatter([], [], marker="x", color="k") - target_line = plt.plot( + goal_point = ax.scatter([], [], marker="x", color="k") + target_line = ax.plot( [], [], linestyle="dotted", alpha=0.5, color="k", zorder=0 )[0] points = [grid_points, goal_point, target_line] - # border = plt.plot(border_vertices[0], border_vertices[1], "k") + if options["show_text"]: + label_top_right = ax_text.text(0, 0.95, "", fontsize=10) + label_bottom_right = ax_text.text(0, 0, "", fontsize=10) + labels = [label_top_right, label_bottom_right] + else: + labels = [] + + # border = ax.plot(border_vertices[0], border_vertices[1], "k") border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) start_pose = self.entity_poses_rad[:, 0] @@ -902,14 +917,35 @@ class File(h5py.File): min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) self.view_size = np.max([options["view_size"], min_view + options["margin"]]) - def title(file_frame: int) -> str: + def get_host_and_filename() -> str: + return f"{socket.gethostname()}:{self.path}\n\n" + + def get_text_datasets(file_frame: int) -> str: """Search for datasets containing text for displaying it in the video""" output = [] for e in self.entities: + entity_output = [] for key, val in e.items(): if val.dtype == object and type(val[0]) == bytes: - output.append(f"{e.name}.{key}='{val[file_frame].decode()}'") - return ", ".join(output) + entity_output.append(f"{key} = '{val[file_frame].decode()}'") + if entity_output: + output.append(e.name) + output += entity_output + return "\n".join(output) + + def get_attributes_as_text() -> str: + attrs = {} + for key, val in self.attrs.items(): + if key == "experiment_setup": + continue + attrs[key] = val + separator = "-" * 130 + "\n" + text = "" + text += separator + text += f"experiment_setup:\n{self.attrs['experiment_setup']}\n" + text += separator + text += "\n".join([f"{key}: {val}" for key, val in attrs.items()]) + return text def get_goal(file_frame: int) -> Optional[np.ndarray]: """Return current goal of robot, if robot exists and has a goal.""" @@ -944,7 +980,7 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border] + points + return lines + entity_polygons + [border] + points + labels n_frames = self.entity_poses.shape[1] if options["cut_frames"] > 0: @@ -998,7 +1034,10 @@ class File(h5py.File): self.middle_of_swarm[1] + self.view_size / 2, ) if options["show_text"]: - ax.set_title(title(file_frame)) + label_top_right.set_text( + get_host_and_filename() + get_text_datasets(file_frame) + ) + label_bottom_right.set_text(get_attributes_as_text()) if options["render_goals"]: goal = get_goal(file_frame) @@ -1027,7 +1066,7 @@ class File(h5py.File): raise Exception( f"Frame is bigger than n_frames {file_frame} of {n_frames}" ) - return lines + entity_polygons + [border] + points + return lines + entity_polygons + [border] + points + labels print(f"Preparing to render n_frames: {n_frames}")