Skip to content
Snippets Groups Projects
Commit 8097ad54 authored by Andi Gerken's avatar Andi Gerken
Browse files

Possibility to get an entity property using a string.

parent 4a2d408e
No related branches found
No related tags found
3 merge requests!37Added calculation of individual ids,!34Added script to update individual ids, added docstrings massively,!32Added quiver plot in evaluation.
Pipeline #50120 failed
...@@ -40,7 +40,7 @@ from matplotlib import animation ...@@ -40,7 +40,7 @@ from matplotlib import animation
from matplotlib import patches from matplotlib import patches
from matplotlib import cm from matplotlib import cm
from tqdm import tqdm from tqdm.auto import tqdm
from subprocess import run from subprocess import run
...@@ -638,7 +638,7 @@ class File(h5py.File): ...@@ -638,7 +638,7 @@ class File(h5py.File):
def select_entity_property( def select_entity_property(
self, self,
predicate: types.LambdaType = None, predicate: types.LambdaType = None,
entity_property: property = Entity.poses, entity_property: Union[property, str] = Entity.poses,
) -> Iterable: ) -> Iterable:
"""Get a property of selected entities. """Get a property of selected entities.
...@@ -648,7 +648,7 @@ class File(h5py.File): ...@@ -648,7 +648,7 @@ class File(h5py.File):
Args: Args:
predicate: a lambda function, selecting entities predicate: a lambda function, selecting entities
(example: lambda e: e.category == "fish") (example: lambda e: e.category == "fish")
entity_property: a property of the Entity class (example: Entity.poses_rad) entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
Returns: Returns:
An three dimensional array of all properties of all entities with the shape (entity, time, property_length). An three dimensional array of all properties of all entities with the shape (entity, time, property_length).
If an entity has a shorter length of the property, the output will be filled with nans. If an entity has a shorter length of the property, the output will be filled with nans.
...@@ -661,6 +661,9 @@ class File(h5py.File): ...@@ -661,6 +661,9 @@ class File(h5py.File):
assert self.common_sampling(entities) is not None assert self.common_sampling(entities) is not None
# Initialize poses output array # Initialize poses output array
if isinstance(entity_property, str):
properties = [entity[entity_property] for entity in entities]
else:
properties = [entity_property.__get__(entity) for entity in entities] properties = [entity_property.__get__(entity) for entity in entities]
max_timesteps = max([0] + [p.shape[0] for p in properties]) max_timesteps = max([0] + [p.shape[0] for p in properties])
...@@ -992,6 +995,7 @@ class File(h5py.File): ...@@ -992,6 +995,7 @@ class File(h5py.File):
"show_text": False, "show_text": False,
"render_goals": False, "render_goals": False,
"render_targets": False, "render_targets": False,
"dpi": 200,
} }
options = { options = {
...@@ -1001,7 +1005,7 @@ class File(h5py.File): ...@@ -1001,7 +1005,7 @@ class File(h5py.File):
fig, ax = plt.subplots(figsize=(10, 10)) fig, ax = plt.subplots(figsize=(10, 10))
ax.set_facecolor("gray") ax.set_facecolor("gray")
plt.tight_layout(pad=0.05)
n_entities = len(self.entities) n_entities = len(self.entities)
lines = [ lines = [
plt.plot([], [], lw=options["linewidth"], zorder=0)[0] plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
...@@ -1040,12 +1044,6 @@ class File(h5py.File): ...@@ -1040,12 +1044,6 @@ class File(h5py.File):
# border = plt.plot(border_vertices[0], border_vertices[1], "k") # border = plt.plot(border_vertices[0], border_vertices[1], "k")
border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)
start_pose = self.entity_poses_rad[:, 0]
self.middle_of_swarm = np.mean(start_pose, axis=0)
min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
self.view_size = np.max([options["view_size"], min_view + options["margin"]])
def title(file_frame: int) -> str: def title(file_frame: int) -> str:
"""Search for datasets containing text for displaying it in the video""" """Search for datasets containing text for displaying it in the video"""
output = [] output = []
...@@ -1094,6 +1092,8 @@ class File(h5py.File): ...@@ -1094,6 +1092,8 @@ class File(h5py.File):
if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None: if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
options["cut_frames_end"] = n_frames options["cut_frames_end"] = n_frames
if options["cut_frames_start"] is None:
options["cut_frames_start"] = 0
frame_range = ( frame_range = (
options["cut_frames_start"], options["cut_frames_start"],
min(n_frames, options["cut_frames_end"]), min(n_frames, options["cut_frames_end"]),
...@@ -1101,6 +1101,12 @@ class File(h5py.File): ...@@ -1101,6 +1101,12 @@ class File(h5py.File):
n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"]) n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])
start_pose = self.entity_poses_rad[:, frame_range[0]]
self.middle_of_swarm = np.mean(start_pose, axis=0)
min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
self.view_size = np.max([options["view_size"], min_view + options["margin"]])
if video_path is not None: if video_path is not None:
pbar = tqdm(range(n_frames)) pbar = tqdm(range(n_frames))
...@@ -1204,10 +1210,8 @@ class File(h5py.File): ...@@ -1204,10 +1210,8 @@ class File(h5py.File):
if not video_path.exists(): if not video_path.exists():
print(f"saving video to {video_path}") print(f"saving video to {video_path}")
writervideo = animation.FFMpegWriter(fps=25) writervideo = animation.FFMpegWriter(fps=self.frequency)
ani.save( ani.save(video_path, writer=writervideo, dpi=options["dpi"])
video_path, plt.close()
writer=writervideo,
)
else: else:
plt.show() plt.show()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment