diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 0518dedd81d5f9c4842583d1e687b14d7e353d4a..269d0116817b1943f09ad945464b55e112d68302 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -130,6 +130,8 @@ def render(args=None): "slow_view": 0.8, "cut_frames": 0, "show_text": False, + "render_goals": False, + "render_targets": False, } for key, value in default_options.items(): diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 0637e8961c5125d92d90979afecb5243e9e150bf..1670f688f46697ed67a7ddaeb7159db0aa290c8a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -22,7 +22,7 @@ import h5py import numpy as np import logging -from typing import Iterable, Union, Tuple, List +from typing import Iterable, Union, Tuple, List, Optional from pathlib import Path import shutil import datetime @@ -840,6 +840,8 @@ class File(h5py.File): "slow_zoom": 0.95, "cut_frames": None, "show_text": False, + "render_goals": False, + "render_targets": False, } options = { @@ -855,6 +857,10 @@ class File(h5py.File): plt.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] + points = [ + plt.scatter([], [], marker="x", color="k"), + plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], + ] categories = [entity.attrs.get("category", None) for entity in self.entities] entity_polygons = [ patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color) @@ -899,6 +905,28 @@ class File(h5py.File): output.append(f"{e.name}.{key}='{val[file_frame].decode()}'") return ", ".join(output) + def get_goal(file_frame: int) -> Optional[np.ndarray]: + """Return current goal of robot, if robot exists and has a goal.""" + goal = None + if "robot" in categories: + robot = self.entities[categories.index("robot")] + try: + goal = robot["goals"][file_frame] + except KeyError: + pass + if goal is not None and np.isnan(goal).any(): + goal = None + return goal + + def get_target(file_frame: int) -> Tuple[List, List]: + """Return line points from robot to target""" + if "robot" in categories: + robot = self.entities[categories.index("robot")] + rpos = robot["positions"][file_frame] + target = robot["targets"][file_frame] + return [rpos[0], target[0]], [rpos[1], target[1]] + return [], [] + def init(): ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0]) ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1]) @@ -910,7 +938,7 @@ class File(h5py.File): for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) - return lines + entity_polygons + [border] + return lines + entity_polygons + [border] + points n_frames = self.entity_poses.shape[1] if options["cut_frames"] > 0: @@ -966,6 +994,14 @@ class File(h5py.File): if options["show_text"]: ax.set_title(title(file_frame)) + if options["render_goals"]: + goal = get_goal(file_frame) + if goal is not None: + points[0].set_offsets(goal) + + if options["render_targets"]: + points[1].set_data(get_target(file_frame)) + poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame ] @@ -985,7 +1021,7 @@ class File(h5py.File): raise Exception( f"Frame is bigger than n_frames {file_frame} of {n_frames}" ) - return lines + entity_polygons + [border] + return lines + entity_polygons + [border] + points print(f"Preparing to render n_frames: {n_frames}")