From 719e2e4e3a8a47c72c1de0eb3f53e4a964050273 Mon Sep 17 00:00:00 2001 From: gianluv00 <gianluv00@mi.fu-berlin.de> Date: Tue, 9 Aug 2022 17:27:26 +0200 Subject: [PATCH] Changed return type of create_entity and create_multiple_entities to robofish.io.Entity --- src/robofish/io/entity.py | 2 +- src/robofish/io/file.py | 10 +++++----- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index 9fe37a3..d35495e 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -30,7 +30,7 @@ class Entity(h5py.Group): orientations: Iterable = None, outlines: Iterable = None, sampling: str = None, - ): + ) -> robofish.io.Entity: poses, positions, orientations, outlines = utils.np_array( poses, positions, orientations, outlines ) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 86a580e..0010b17 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -439,7 +439,7 @@ class File(h5py.File): orientations: Iterable = None, outlines: Iterable = None, sampling: str = None, - ) -> str: + ) -> robofish.io.Entity: """Creates a new single entity. Args: @@ -478,7 +478,7 @@ class File(h5py.File): names: Iterable[str] = None, outlines=None, sampling=None, - ) -> Iterable: + ) -> Iterable[robofish.io.Entity]: """Creates multiple entities. Args: @@ -496,7 +496,7 @@ class File(h5py.File): ), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}" assert poses.shape[2] in [3, 4] agents = poses.shape[0] - entity_names = [] + entities = [] for i in range(agents): e_name = None if names is None else names[i] @@ -504,7 +504,7 @@ class File(h5py.File): outlines if outlines is None or outlines.ndim == 3 else outlines[i] ) - entity_names.append( + entities.append( self.create_entity( category=category, sampling=sampling, @@ -513,7 +513,7 @@ class File(h5py.File): outlines=e_outline, ) ) - return entity_names + return entities def update_calculated_data(self, verbose=False): changed = any([e.update_calculated_data(verbose) for e in self.entities]) -- GitLab