From 719e2e4e3a8a47c72c1de0eb3f53e4a964050273 Mon Sep 17 00:00:00 2001
From: gianluv00 <gianluv00@mi.fu-berlin.de>
Date: Tue, 9 Aug 2022 17:27:26 +0200
Subject: [PATCH] Changed return type of create_entity and
 create_multiple_entities to robofish.io.Entity

---
 src/robofish/io/entity.py |  2 +-
 src/robofish/io/file.py   | 10 +++++-----
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py
index 9fe37a3..d35495e 100644
--- a/src/robofish/io/entity.py
+++ b/src/robofish/io/entity.py
@@ -30,7 +30,7 @@ class Entity(h5py.Group):
         orientations: Iterable = None,
         outlines: Iterable = None,
         sampling: str = None,
-    ):
+    ) -> robofish.io.Entity:
         poses, positions, orientations, outlines = utils.np_array(
             poses, positions, orientations, outlines
         )
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 86a580e..0010b17 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -439,7 +439,7 @@ class File(h5py.File):
         orientations: Iterable = None,
         outlines: Iterable = None,
         sampling: str = None,
-    ) -> str:
+    ) -> robofish.io.Entity:
         """Creates a new single entity.
 
         Args:
@@ -478,7 +478,7 @@ class File(h5py.File):
         names: Iterable[str] = None,
         outlines=None,
         sampling=None,
-    ) -> Iterable:
+    ) -> Iterable[robofish.io.Entity]:
         """Creates multiple entities.
 
         Args:
@@ -496,7 +496,7 @@ class File(h5py.File):
         ), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}"
         assert poses.shape[2] in [3, 4]
         agents = poses.shape[0]
-        entity_names = []
+        entities = []
 
         for i in range(agents):
             e_name = None if names is None else names[i]
@@ -504,7 +504,7 @@ class File(h5py.File):
                 outlines if outlines is None or outlines.ndim == 3 else outlines[i]
             )
 
-            entity_names.append(
+            entities.append(
                 self.create_entity(
                     category=category,
                     sampling=sampling,
@@ -513,7 +513,7 @@ class File(h5py.File):
                     outlines=e_outline,
                 )
             )
-        return entity_names
+        return entities
 
     def update_calculated_data(self, verbose=False):
         changed = any([e.update_calculated_data(verbose) for e in self.entities])
-- 
GitLab