Skip to content
Snippets Groups Projects
Commit ec1b712a authored by mhocke's avatar mhocke
Browse files

Render goals and targets

parent 5a78c9f7
No related branches found
No related tags found
No related merge requests found
Pipeline #48638 passed
...@@ -130,6 +130,8 @@ def render(args=None): ...@@ -130,6 +130,8 @@ def render(args=None):
"slow_view": 0.8, "slow_view": 0.8,
"cut_frames": 0, "cut_frames": 0,
"show_text": False, "show_text": False,
"render_goals": False,
"render_targets": False,
} }
for key, value in default_options.items(): for key, value in default_options.items():
......
...@@ -22,7 +22,7 @@ import h5py ...@@ -22,7 +22,7 @@ import h5py
import numpy as np import numpy as np
import logging import logging
from typing import Iterable, Union, Tuple, List from typing import Iterable, Union, Tuple, List, Optional
from pathlib import Path from pathlib import Path
import shutil import shutil
import datetime import datetime
...@@ -837,6 +837,8 @@ class File(h5py.File): ...@@ -837,6 +837,8 @@ class File(h5py.File):
"slow_zoom": 0.95, "slow_zoom": 0.95,
"cut_frames": None, "cut_frames": None,
"show_text": False, "show_text": False,
"render_goals": False,
"render_targets": False,
} }
options = { options = {
...@@ -852,6 +854,10 @@ class File(h5py.File): ...@@ -852,6 +854,10 @@ class File(h5py.File):
plt.plot([], [], lw=options["linewidth"], zorder=0)[0] plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
for _ in range(n_entities) for _ in range(n_entities)
] ]
points = [
plt.scatter([], [], marker="x", color="k"),
plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
]
categories = [entity.attrs.get("category", None) for entity in self.entities] categories = [entity.attrs.get("category", None) for entity in self.entities]
entity_polygons = [ entity_polygons = [
patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color) patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color)
...@@ -896,6 +902,28 @@ class File(h5py.File): ...@@ -896,6 +902,28 @@ class File(h5py.File):
output.append(f"{e.name}.{key}='{val[file_frame].decode()}'") output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
return ", ".join(output) return ", ".join(output)
def get_goal(file_frame: int) -> Optional[np.ndarray]:
"""Return current goal of robot, if robot exists and has a goal."""
goal = None
if "robot" in categories:
robot = self.entities[categories.index("robot")]
try:
goal = robot["goals"][file_frame]
except KeyError:
pass
if goal is not None and np.isnan(goal).any():
goal = None
return goal
def get_target(file_frame: int) -> Tuple[List, List]:
"""Return line points from robot to target"""
if "robot" in categories:
robot = self.entities[categories.index("robot")]
rpos = robot["positions"][file_frame]
target = robot["targets"][file_frame]
return [rpos[0], target[0]], [rpos[1], target[1]]
return [], []
def init(): def init():
ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0]) ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1]) ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
...@@ -907,7 +935,7 @@ class File(h5py.File): ...@@ -907,7 +935,7 @@ class File(h5py.File):
for e_poly in entity_polygons: for e_poly in entity_polygons:
ax.add_patch(e_poly) ax.add_patch(e_poly)
ax.add_patch(border) ax.add_patch(border)
return lines + entity_polygons + [border] return lines + entity_polygons + [border] + points
n_frames = self.entity_poses.shape[1] n_frames = self.entity_poses.shape[1]
if options["cut_frames"] > 0: if options["cut_frames"] > 0:
...@@ -963,6 +991,14 @@ class File(h5py.File): ...@@ -963,6 +991,14 @@ class File(h5py.File):
if options["show_text"]: if options["show_text"]:
ax.set_title(title(file_frame)) ax.set_title(title(file_frame))
if options["render_goals"]:
goal = get_goal(file_frame)
if goal is not None:
points[0].set_offsets(goal)
if options["render_targets"]:
points[1].set_data(get_target(file_frame))
poses_trails = entity_poses[ poses_trails = entity_poses[
:, max(0, file_frame - options["trail"]) : file_frame :, max(0, file_frame - options["trail"]) : file_frame
] ]
...@@ -982,7 +1018,7 @@ class File(h5py.File): ...@@ -982,7 +1018,7 @@ class File(h5py.File):
raise Exception( raise Exception(
f"Frame is bigger than n_frames {file_frame} of {n_frames}" f"Frame is bigger than n_frames {file_frame} of {n_frames}"
) )
return lines + entity_polygons + [border] return lines + entity_polygons + [border] + points
print(f"Preparing to render n_frames: {n_frames}") print(f"Preparing to render n_frames: {n_frames}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment