From 574b27ee1107675c5ef83304ad4d48b16925719a Mon Sep 17 00:00:00 2001
From: Mathis Hocke <mathis.hocke@fu-berlin.de>
Date: Thu, 25 Jul 2024 14:14:03 +0200
Subject: [PATCH] Fix some type annotations

---
 src/robofish/io/entity.py | 54 +++++++++++++++++++--------------------
 src/robofish/io/file.py   | 54 +++++++++++++++++++--------------------
 2 files changed, 53 insertions(+), 55 deletions(-)

diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py
index bdea2c2..c2e7a67 100644
--- a/src/robofish/io/entity.py
+++ b/src/robofish/io/entity.py
@@ -2,13 +2,11 @@
 .. include:: ../../../docs/entity.md
 """
 
-import robofish.io
 import robofish.io.utils as utils
 
 import h5py
 import numpy as np
-from typing import Iterable, Union
-import datetime
+from typing import Iterable, Optional
 import logging
 import deprecation
 
@@ -24,14 +22,14 @@ class Entity(h5py.Group):
         cla,
         entities_group,
         category: str,
-        poses: Iterable = None,
-        name: str = None,
-        individual_id: int = None,
-        positions: Iterable = None,
-        orientations: Iterable = None,
-        outlines: Iterable = None,
-        sampling: str = None,
-    ):
+        poses: Optional[Iterable] = None,
+        name: Optional[str] = None,
+        individual_id: Optional[int] = None,
+        positions: Optional[Iterable] = None,
+        orientations: Optional[Iterable] = None,
+        outlines: Optional[Iterable] = None,
+        sampling: Optional[str] = None,
+    ) -> "Entity":
         poses, positions, orientations, outlines = utils.np_array(
             poses, positions, orientations, outlines
         )
@@ -92,29 +90,29 @@ class Entity(h5py.Group):
         return ori_vec
 
     @property
-    def group_name(self):
+    def group_name(self) -> str:
         return super().name
 
     @property
-    def name(self):
+    def name(self) -> str:
         return self.group_name.split("/")[-1]
 
     @property
-    def category(self):
+    def category(self) -> str:
         return self.attrs["category"]
 
-    def create_outlines(self, outlines: Iterable, sampling=None):
+    def create_outlines(self, outlines: Iterable, sampling=None) -> None:
         outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32)
         if sampling is not None:
             outlines.attrs["sampling"] = sampling
 
     def create_poses(
         self,
-        poses: Iterable = None,
-        positions: Iterable = None,
-        orientations: Iterable = None,
-        sampling: str = None,
-    ):
+        poses: Optional[Iterable] = None,
+        positions: Optional[Iterable] = None,
+        orientations: Optional[Iterable] = None,
+        sampling: Optional[str] = None,
+    ) -> None:
         poses, positions, orientations = utils.np_array(poses, positions, orientations)
 
         # Either poses or positions not both
@@ -147,11 +145,11 @@ class Entity(h5py.Group):
                 orientations.attrs["sampling"] = sampling
 
     @property
-    def positions(self):
+    def positions(self) -> np.ndarray:
         return self["positions"]
 
     @property
-    def orientations(self):
+    def orientations(self) -> np.ndarray:
         if not "orientations" in self:
             # If no orientation is given, the default direction is to the right
             return np.tile([0, 1], (self.positions.shape[0], 1))
@@ -205,7 +203,7 @@ class Entity(h5py.Group):
         return poses_with_calculated_orientation
 
     @property
-    def poses_hash(self):
+    def poses_hash(self) -> int:
         # The hash of h5py datasets changes each time the file is reopened.
         # Also the hash of casting the array to bytes and calculating the hash changes.
         def npsumhash(a):
@@ -220,11 +218,11 @@ class Entity(h5py.Group):
         return h
 
     @property
-    def poses(self):
+    def poses(self) -> np.ndarray:
         return np.concatenate([self.positions, self.orientations], axis=1)
 
     @property
-    def poses_rad(self):
+    def poses_rad(self) -> np.ndarray:
         return np.concatenate([self.positions, self.orientations_rad], axis=1)
 
     @property
@@ -260,7 +258,7 @@ class Entity(h5py.Group):
         turn = utils.limit_angle_range(diff[:, 2], _range=(-np.pi, np.pi))
         return np.stack([speed, turn], axis=-1)
 
-    def update_calculated_data(self, verbose=False, force_update=False):
+    def update_calculated_data(self, verbose=False, force_update=False) -> bool:
         changed = False
         if (
             "poses_hash" not in self.attrs
@@ -341,7 +339,7 @@ class Entity(h5py.Group):
         return np.stack([speed, turn], axis=-1)
 
     @property
-    def actions_speeds_turns(self):
+    def actions_speeds_turns(self) -> np.ndarray:
         if "calculated_actions_speeds_turns" in self:
             assert (
                 self.attrs["poses_hash"] == self.poses_hash
@@ -351,7 +349,7 @@ class Entity(h5py.Group):
             return self.calculate_actions_speeds_turns()
 
     @property
-    def orientations_rad(self):
+    def orientations_rad(self) -> np.ndarray:
         if "calculated_orientations_rad" in self:
             assert (
                 self.attrs["poses_hash"] == self.poses_hash
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 1ce2f1e..58ece8b 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -63,20 +63,20 @@ class File(h5py.File):
 
     def __init__(
         self,
-        path: Union[str, Path] = None,
+        path: Optinal[Union[str, Path]] = None,
         mode: str = "r",
         *,  # PEP 3102
-        world_size_cm: List[int] = None,
+        world_size_cm: Optional[List[int]] = None,
         world_shape: str = "rectangle",
         validate: bool = False,
         validate_when_saving: bool = True,
         strict_validate: bool = False,
         format_version: List[int] = default_format_version,
         format_url: str = default_format_url,
-        sampling_name: str = None,
-        frequency_hz: int = None,
-        monotonic_time_points_us: Iterable = None,
-        calendar_time_points: Iterable = None,
+        sampling_name: Optional[str] = None,
+        frequency_hz: Optional[int] = None,
+        monotonic_time_points_us: Optional[Iterable] = None,
+        calendar_time_points: Optional[Iterable] = None,
         open_copy: bool = False,
         validate_poses_hash: bool = True,
         calculate_data_on_close: bool = True,
@@ -332,11 +332,11 @@ class File(h5py.File):
 
     def create_sampling(
         self,
-        name: str = None,
-        frequency_hz: int = None,
-        monotonic_time_points_us: Iterable = None,
-        calendar_time_points: Iterable = None,
-        default: bool = False,
+        name: Optional[str] = None,
+        frequency_hz: Optional[int] = None,
+        monotonic_time_points_us: Optional[Iterable] = None,
+        calendar_time_points: Optional[Iterable] = None,
+        default: Optional[bool] = False,
     ):
         # Find Name for sampling if none is given
         if name is None:
@@ -465,24 +465,24 @@ class File(h5py.File):
     def create_entity(
         self,
         category: str,
-        poses: Iterable = None,
-        name: str = None,
-        individual_id: int = None,
-        positions: Iterable = None,
-        orientations: Iterable = None,
-        outlines: Iterable = None,
-        sampling: str = None,
-    ) -> str:
+        poses: Optional[Iterable] = None,
+        name: Optional[str] = None,
+        individual_id: Optional[int] = None,
+        positions: Optional[Iterable] = None,
+        orientations: Optional[Iterable] = None,
+        outlines: Optional[Iterable] = None,
+        sampling: Optional[str] = None,
+    ) -> Entity:
         """Creates a new single entity.
 
         Args:
             TODO
             category: the  of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
             poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y).
-            poses_rad: optional two dimensional containing the poses of the entity (x,y, orientation_rad).
             name: optional name of the entity. If no name is given, the  is used with an id (e.g. 'fish_1')
             individual_id (int, optional): invididual id of the entity.
             outlines: optional three dimensional array, containing the outlines of the entity
+            sampling: The string refference to the sampling. If none is given, the standard sampling from creating the file is used.
         Returns:
             Name of the created entity
         """
@@ -510,11 +510,11 @@ class File(h5py.File):
         self,
         category: str,
         poses: Iterable,
-        names: Iterable[str] = None,
-        individual_ids: Iterable[int] = None,
+        names: Optional[Iterable[str]] = None,
+        individual_ids: Optional[Iterable[int]] = None,
         outlines=None,
         sampling=None,
-    ) -> Iterable:
+    ) -> Iterable[Entity]:
         """Creates multiple entities.
 
         Args:
@@ -868,13 +868,13 @@ class File(h5py.File):
 
     def plot(
         self,
-        ax: matplotlib.axes = None,
+        ax: Optional[matplotlib.axes] = None,
         lw_distances: bool = False,
         lw: int = 2,
         ms: int = 32,
-        figsize: Tuple[int] = None,
+        figsize: Optional[Tuple[int]] = None,
         step_size: int = 25,
-        c: List = None,
+        c: Optional[List] = None,
         cmap: matplotlib.colors.Colormap = "Set1",
         skip_timesteps=0,
         max_timesteps=None,
@@ -985,7 +985,7 @@ class File(h5py.File):
 
         return ax
 
-    def render(self, video_path: Union[str, Path] = None, **kwargs: Dict) -> None:
+    def render(self, video_path: Optional[Union[str, Path]] = None, **kwargs: Dict) -> None:
         """Render a video of the file.
 
         The tracks are rendered in a video using matplotlib.animation.FuncAnimation.
-- 
GitLab