From 574b27ee1107675c5ef83304ad4d48b16925719a Mon Sep 17 00:00:00 2001 From: Mathis Hocke <mathis.hocke@fu-berlin.de> Date: Thu, 25 Jul 2024 14:14:03 +0200 Subject: [PATCH] Fix some type annotations --- src/robofish/io/entity.py | 54 +++++++++++++++++++-------------------- src/robofish/io/file.py | 54 +++++++++++++++++++-------------------- 2 files changed, 53 insertions(+), 55 deletions(-) diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index bdea2c2..c2e7a67 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -2,13 +2,11 @@ .. include:: ../../../docs/entity.md """ -import robofish.io import robofish.io.utils as utils import h5py import numpy as np -from typing import Iterable, Union -import datetime +from typing import Iterable, Optional import logging import deprecation @@ -24,14 +22,14 @@ class Entity(h5py.Group): cla, entities_group, category: str, - poses: Iterable = None, - name: str = None, - individual_id: int = None, - positions: Iterable = None, - orientations: Iterable = None, - outlines: Iterable = None, - sampling: str = None, - ): + poses: Optional[Iterable] = None, + name: Optional[str] = None, + individual_id: Optional[int] = None, + positions: Optional[Iterable] = None, + orientations: Optional[Iterable] = None, + outlines: Optional[Iterable] = None, + sampling: Optional[str] = None, + ) -> "Entity": poses, positions, orientations, outlines = utils.np_array( poses, positions, orientations, outlines ) @@ -92,29 +90,29 @@ class Entity(h5py.Group): return ori_vec @property - def group_name(self): + def group_name(self) -> str: return super().name @property - def name(self): + def name(self) -> str: return self.group_name.split("/")[-1] @property - def category(self): + def category(self) -> str: return self.attrs["category"] - def create_outlines(self, outlines: Iterable, sampling=None): + def create_outlines(self, outlines: Iterable, sampling=None) -> None: outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32) if sampling is not None: outlines.attrs["sampling"] = sampling def create_poses( self, - poses: Iterable = None, - positions: Iterable = None, - orientations: Iterable = None, - sampling: str = None, - ): + poses: Optional[Iterable] = None, + positions: Optional[Iterable] = None, + orientations: Optional[Iterable] = None, + sampling: Optional[str] = None, + ) -> None: poses, positions, orientations = utils.np_array(poses, positions, orientations) # Either poses or positions not both @@ -147,11 +145,11 @@ class Entity(h5py.Group): orientations.attrs["sampling"] = sampling @property - def positions(self): + def positions(self) -> np.ndarray: return self["positions"] @property - def orientations(self): + def orientations(self) -> np.ndarray: if not "orientations" in self: # If no orientation is given, the default direction is to the right return np.tile([0, 1], (self.positions.shape[0], 1)) @@ -205,7 +203,7 @@ class Entity(h5py.Group): return poses_with_calculated_orientation @property - def poses_hash(self): + def poses_hash(self) -> int: # The hash of h5py datasets changes each time the file is reopened. # Also the hash of casting the array to bytes and calculating the hash changes. def npsumhash(a): @@ -220,11 +218,11 @@ class Entity(h5py.Group): return h @property - def poses(self): + def poses(self) -> np.ndarray: return np.concatenate([self.positions, self.orientations], axis=1) @property - def poses_rad(self): + def poses_rad(self) -> np.ndarray: return np.concatenate([self.positions, self.orientations_rad], axis=1) @property @@ -260,7 +258,7 @@ class Entity(h5py.Group): turn = utils.limit_angle_range(diff[:, 2], _range=(-np.pi, np.pi)) return np.stack([speed, turn], axis=-1) - def update_calculated_data(self, verbose=False, force_update=False): + def update_calculated_data(self, verbose=False, force_update=False) -> bool: changed = False if ( "poses_hash" not in self.attrs @@ -341,7 +339,7 @@ class Entity(h5py.Group): return np.stack([speed, turn], axis=-1) @property - def actions_speeds_turns(self): + def actions_speeds_turns(self) -> np.ndarray: if "calculated_actions_speeds_turns" in self: assert ( self.attrs["poses_hash"] == self.poses_hash @@ -351,7 +349,7 @@ class Entity(h5py.Group): return self.calculate_actions_speeds_turns() @property - def orientations_rad(self): + def orientations_rad(self) -> np.ndarray: if "calculated_orientations_rad" in self: assert ( self.attrs["poses_hash"] == self.poses_hash diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 1ce2f1e..58ece8b 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -63,20 +63,20 @@ class File(h5py.File): def __init__( self, - path: Union[str, Path] = None, + path: Optinal[Union[str, Path]] = None, mode: str = "r", *, # PEP 3102 - world_size_cm: List[int] = None, + world_size_cm: Optional[List[int]] = None, world_shape: str = "rectangle", validate: bool = False, validate_when_saving: bool = True, strict_validate: bool = False, format_version: List[int] = default_format_version, format_url: str = default_format_url, - sampling_name: str = None, - frequency_hz: int = None, - monotonic_time_points_us: Iterable = None, - calendar_time_points: Iterable = None, + sampling_name: Optional[str] = None, + frequency_hz: Optional[int] = None, + monotonic_time_points_us: Optional[Iterable] = None, + calendar_time_points: Optional[Iterable] = None, open_copy: bool = False, validate_poses_hash: bool = True, calculate_data_on_close: bool = True, @@ -332,11 +332,11 @@ class File(h5py.File): def create_sampling( self, - name: str = None, - frequency_hz: int = None, - monotonic_time_points_us: Iterable = None, - calendar_time_points: Iterable = None, - default: bool = False, + name: Optional[str] = None, + frequency_hz: Optional[int] = None, + monotonic_time_points_us: Optional[Iterable] = None, + calendar_time_points: Optional[Iterable] = None, + default: Optional[bool] = False, ): # Find Name for sampling if none is given if name is None: @@ -465,24 +465,24 @@ class File(h5py.File): def create_entity( self, category: str, - poses: Iterable = None, - name: str = None, - individual_id: int = None, - positions: Iterable = None, - orientations: Iterable = None, - outlines: Iterable = None, - sampling: str = None, - ) -> str: + poses: Optional[Iterable] = None, + name: Optional[str] = None, + individual_id: Optional[int] = None, + positions: Optional[Iterable] = None, + orientations: Optional[Iterable] = None, + outlines: Optional[Iterable] = None, + sampling: Optional[str] = None, + ) -> Entity: """Creates a new single entity. Args: TODO category: the of the entity. The canonical values are ['organism', 'robot', 'obstacle']. poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y). - poses_rad: optional two dimensional containing the poses of the entity (x,y, orientation_rad). name: optional name of the entity. If no name is given, the is used with an id (e.g. 'fish_1') individual_id (int, optional): invididual id of the entity. outlines: optional three dimensional array, containing the outlines of the entity + sampling: The string refference to the sampling. If none is given, the standard sampling from creating the file is used. Returns: Name of the created entity """ @@ -510,11 +510,11 @@ class File(h5py.File): self, category: str, poses: Iterable, - names: Iterable[str] = None, - individual_ids: Iterable[int] = None, + names: Optional[Iterable[str]] = None, + individual_ids: Optional[Iterable[int]] = None, outlines=None, sampling=None, - ) -> Iterable: + ) -> Iterable[Entity]: """Creates multiple entities. Args: @@ -868,13 +868,13 @@ class File(h5py.File): def plot( self, - ax: matplotlib.axes = None, + ax: Optional[matplotlib.axes] = None, lw_distances: bool = False, lw: int = 2, ms: int = 32, - figsize: Tuple[int] = None, + figsize: Optional[Tuple[int]] = None, step_size: int = 25, - c: List = None, + c: Optional[List] = None, cmap: matplotlib.colors.Colormap = "Set1", skip_timesteps=0, max_timesteps=None, @@ -985,7 +985,7 @@ class File(h5py.File): return ax - def render(self, video_path: Union[str, Path] = None, **kwargs: Dict) -> None: + def render(self, video_path: Optional[Union[str, Path]] = None, **kwargs: Dict) -> None: """Render a video of the file. The tracks are rendered in a video using matplotlib.animation.FuncAnimation. -- GitLab