diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 0efc84c86cf9f4a4fd986da6ba8d4ca863e47665..21f8365c5476a748959aaf16cfc4107bcd5e1cef 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -438,3 +438,7 @@ def update_world_shape(args: dict = None) -> None: pbar.update(1) print("Update finished.") + + +if __name__ == "__main__": + render() diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 58ece8b137f2801349548f5eb0b88b8faee06008..f289c87eadde64d1c60d33a9bc68d5fac5396b74 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -16,34 +16,31 @@ # ----------------------------------------------------------- from __future__ import annotations -import robofish.io -from robofish.io.entity import Entity -import h5py - -import numpy as np +import datetime import logging -from typing import Iterable, Union, Tuple, List, Optional, Dict -from pathlib import Path +import platform import shutil -import datetime import tempfile -import uuid -import deprecation import types +import uuid import warnings +from pathlib import Path +from subprocess import run from textwrap import wrap -import platform +from typing import Dict, Iterable, List, Optional, Tuple, Union +import deprecation +import h5py import matplotlib -import matplotlib.pyplot as plt import matplotlib.cm -from matplotlib import animation -from matplotlib import patches - +import matplotlib.pyplot as plt +import numpy as np +from matplotlib import animation, patches from tqdm.auto import tqdm -from subprocess import run +import robofish.io +from robofish.io.entity import Entity # Remember: Update docstring when updating these two global variables default_format_version = np.array([1, 0], dtype=np.int32) @@ -1024,19 +1021,27 @@ class File(h5py.File): f"ffmpeg is required to store videos. Please install it.\n{e}" ) - def shape_vertices(scale=1) -> np.ndarray: - base_shape = np.array( - [ - (+3.0, +0.0), - (+2.5, +1.0), - (+1.5, +1.5), - (-2.5, +1.0), - (-4.5, +0.0), - (-2.5, -1.0), - (+1.5, -1.5), - (+2.5, -1.0), - ] - ) + def shape_vertices(scale=1, has_orientation=True) -> np.ndarray: + if has_orientation: + base_shape = np.array( + [ + (+3.0, +0.0), + (+2.5, +1.0), + (+1.5, +1.5), + (-2.5, +1.0), + (-4.5, +0.0), + (-2.5, -1.0), + (+1.5, -1.5), + (+2.5, -1.0), + ] + ) + else: + # show a circle if there is no orientation + num_vertices = 8 + radius = 2 + angles = np.linspace(0, 2 * np.pi, num_vertices, endpoint=False) + base_shape = np.column_stack((np.cos(angles), np.sin(angles))) * radius + return base_shape * scale default_options = { @@ -1111,15 +1116,22 @@ class File(h5py.File): entity_polygons = [ patches.Polygon( - shape_vertices(options["entity_scale"]), + shape_vertices( + options["entity_scale"], + has_orientation=entity.orientations_rad[:].std() > 0, + ), edgecolor=edgecolor, facecolor=color, alpha=0.8, ) - for edgecolor, color in [ - ("k", "white") + for edgecolor, color, entity in [ + ("k", "white", self.entities[entity]) if category == "robot" - else (entity_colors[entity], entity_colors[entity]) + else ( + entity_colors[entity], + entity_colors[entity], + self.entities[entity], + ) for entity, category in enumerate(categories) ] ] @@ -1152,13 +1164,33 @@ class File(h5py.File): xv, yv = np.meshgrid(x, y) points = [ - plt.scatter([], [], marker="x", color="k"), plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], plt.scatter(xv, yv, c="gray", s=1.5), ] + goal_components = [ + plt.scatter([], [], marker="x", color="k"), # goal + plt.plot([], [], linestyle="dotted", alpha=0.5, color="k")[0], + ] + border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) + goal_threshold_cm = self.attrs.get("change_goal_threshold", 0.05) * 100 + + def update_goal(file_frame: int): + num_vertices = 32 + goal = get_goal(file_frame) + if goal is not None: + goal_components[0].set_offsets(goal) + angles = np.linspace(0, 2 * np.pi, num_vertices) + shape = ( + np.column_stack((np.cos(angles), np.sin(angles))) + * goal_threshold_cm + ) + goal_components[1].set_data( + goal[0] + shape[:, 0], goal[1] + shape[:, 1] + ) + def title(file_frame: int) -> str: """Search for datasets containing text for displaying it in the video""" output = [] @@ -1308,12 +1340,10 @@ class File(h5py.File): annotations[i].xy = this_pose[i, :2] if options["render_goals"]: - goal = get_goal(file_frame) - if goal is not None: - points[0].set_offsets(goal) + update_goal(file_frame) if options["render_targets"]: - points[1].set_data(get_target(file_frame)) + points[0].set_data(get_target(file_frame)) poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame