diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index 5a6c806960a74512c6317c46fc3d813258da9cdd..3d1649cd7604053dac8e74b0a673883403a056db 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -26,6 +26,7 @@ class Entity(h5py.Group): category: str, poses: Iterable = None, name: str = None, + individual_id: int = None, positions: Iterable = None, orientations: Iterable = None, outlines: Iterable = None, @@ -55,6 +56,9 @@ class Entity(h5py.Group): entity.attrs["category"] = category + if individual_id is not None: + entity.attrs["individual_id"] = individual_id + entity.create_poses(poses, positions, orientations, sampling) if outlines is not None: diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 8e746c4869334e461a809e10cbdc012435945757..08dabb9f1c8845330e6044981a3627f16fe0c48c 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -447,6 +447,7 @@ class File(h5py.File): category: str, poses: Iterable = None, name: str = None, + individual_id: int = None, positions: Iterable = None, orientations: Iterable = None, outlines: Iterable = None, @@ -460,6 +461,7 @@ class File(h5py.File): poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y). poses_rad: optional two dimensional containing the poses of the entity (x,y, orientation_rad). name: optional name of the entity. If no name is given, the is used with an id (e.g. 'fish_1') + individual_id (int, optional): invididual id of the entity. outlines: optional three dimensional array, containing the outlines of the entity Returns: Name of the created entity @@ -471,14 +473,15 @@ class File(h5py.File): ) entity = robofish.io.Entity.create_entity( - self["entities"], - category, - poses, - name, - positions, - orientations, - outlines, - sampling, + entities_group=self["entities"], + category=category, + poses=poses, + name=name, + individual_id=individual_id, + positions=positions, + orientations=orientations, + outlines=outlines, + sampling=sampling, ) return entity @@ -488,6 +491,7 @@ class File(h5py.File): category: str, poses: Iterable, names: Iterable[str] = None, + individual_ids: Iterable[int] = None, outlines=None, sampling=None, ) -> Iterable: @@ -497,6 +501,7 @@ class File(h5py.File): category: The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle']. poses: three dimensional array, containing the poses of the entity. name: optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1') + individual_ids (Iterable[int]): optional array of individual ids of the entities. outlines: optional array, containing the outlines of the entities, either a three dimensional common outline array can be given, or a four dimensional array. sampling: The string refference to the sampling. If none is given, the standard sampling from creating the file is used. Returns: @@ -515,13 +520,14 @@ class File(h5py.File): e_outline = ( outlines if outlines is None or outlines.ndim == 3 else outlines[i] ) - + individual_id = None if individual_ids is None else individual_ids[i] entity_names.append( self.create_entity( category=category, sampling=sampling, poses=poses[i], name=e_name, + individual_id=individual_id, outlines=e_outline, ) )