Skip to content
Snippets Groups Projects
Commit 574b27ee authored by mhocke's avatar mhocke
Browse files

Fix some type annotations

parent 06bfd9fa
No related branches found
No related tags found
1 merge request!50Merge dev_mathis with multiple changes
Pipeline #61549 failed
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
.. include:: ../../../docs/entity.md .. include:: ../../../docs/entity.md
""" """
import robofish.io
import robofish.io.utils as utils import robofish.io.utils as utils
import h5py import h5py
import numpy as np import numpy as np
from typing import Iterable, Union from typing import Iterable, Optional
import datetime
import logging import logging
import deprecation import deprecation
...@@ -24,14 +22,14 @@ class Entity(h5py.Group): ...@@ -24,14 +22,14 @@ class Entity(h5py.Group):
cla, cla,
entities_group, entities_group,
category: str, category: str,
poses: Iterable = None, poses: Optional[Iterable] = None,
name: str = None, name: Optional[str] = None,
individual_id: int = None, individual_id: Optional[int] = None,
positions: Iterable = None, positions: Optional[Iterable] = None,
orientations: Iterable = None, orientations: Optional[Iterable] = None,
outlines: Iterable = None, outlines: Optional[Iterable] = None,
sampling: str = None, sampling: Optional[str] = None,
): ) -> "Entity":
poses, positions, orientations, outlines = utils.np_array( poses, positions, orientations, outlines = utils.np_array(
poses, positions, orientations, outlines poses, positions, orientations, outlines
) )
...@@ -92,29 +90,29 @@ class Entity(h5py.Group): ...@@ -92,29 +90,29 @@ class Entity(h5py.Group):
return ori_vec return ori_vec
@property @property
def group_name(self): def group_name(self) -> str:
return super().name return super().name
@property @property
def name(self): def name(self) -> str:
return self.group_name.split("/")[-1] return self.group_name.split("/")[-1]
@property @property
def category(self): def category(self) -> str:
return self.attrs["category"] return self.attrs["category"]
def create_outlines(self, outlines: Iterable, sampling=None): def create_outlines(self, outlines: Iterable, sampling=None) -> None:
outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32) outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32)
if sampling is not None: if sampling is not None:
outlines.attrs["sampling"] = sampling outlines.attrs["sampling"] = sampling
def create_poses( def create_poses(
self, self,
poses: Iterable = None, poses: Optional[Iterable] = None,
positions: Iterable = None, positions: Optional[Iterable] = None,
orientations: Iterable = None, orientations: Optional[Iterable] = None,
sampling: str = None, sampling: Optional[str] = None,
): ) -> None:
poses, positions, orientations = utils.np_array(poses, positions, orientations) poses, positions, orientations = utils.np_array(poses, positions, orientations)
# Either poses or positions not both # Either poses or positions not both
...@@ -147,11 +145,11 @@ class Entity(h5py.Group): ...@@ -147,11 +145,11 @@ class Entity(h5py.Group):
orientations.attrs["sampling"] = sampling orientations.attrs["sampling"] = sampling
@property @property
def positions(self): def positions(self) -> np.ndarray:
return self["positions"] return self["positions"]
@property @property
def orientations(self): def orientations(self) -> np.ndarray:
if not "orientations" in self: if not "orientations" in self:
# If no orientation is given, the default direction is to the right # If no orientation is given, the default direction is to the right
return np.tile([0, 1], (self.positions.shape[0], 1)) return np.tile([0, 1], (self.positions.shape[0], 1))
...@@ -205,7 +203,7 @@ class Entity(h5py.Group): ...@@ -205,7 +203,7 @@ class Entity(h5py.Group):
return poses_with_calculated_orientation return poses_with_calculated_orientation
@property @property
def poses_hash(self): def poses_hash(self) -> int:
# The hash of h5py datasets changes each time the file is reopened. # The hash of h5py datasets changes each time the file is reopened.
# Also the hash of casting the array to bytes and calculating the hash changes. # Also the hash of casting the array to bytes and calculating the hash changes.
def npsumhash(a): def npsumhash(a):
...@@ -220,11 +218,11 @@ class Entity(h5py.Group): ...@@ -220,11 +218,11 @@ class Entity(h5py.Group):
return h return h
@property @property
def poses(self): def poses(self) -> np.ndarray:
return np.concatenate([self.positions, self.orientations], axis=1) return np.concatenate([self.positions, self.orientations], axis=1)
@property @property
def poses_rad(self): def poses_rad(self) -> np.ndarray:
return np.concatenate([self.positions, self.orientations_rad], axis=1) return np.concatenate([self.positions, self.orientations_rad], axis=1)
@property @property
...@@ -260,7 +258,7 @@ class Entity(h5py.Group): ...@@ -260,7 +258,7 @@ class Entity(h5py.Group):
turn = utils.limit_angle_range(diff[:, 2], _range=(-np.pi, np.pi)) turn = utils.limit_angle_range(diff[:, 2], _range=(-np.pi, np.pi))
return np.stack([speed, turn], axis=-1) return np.stack([speed, turn], axis=-1)
def update_calculated_data(self, verbose=False, force_update=False): def update_calculated_data(self, verbose=False, force_update=False) -> bool:
changed = False changed = False
if ( if (
"poses_hash" not in self.attrs "poses_hash" not in self.attrs
...@@ -341,7 +339,7 @@ class Entity(h5py.Group): ...@@ -341,7 +339,7 @@ class Entity(h5py.Group):
return np.stack([speed, turn], axis=-1) return np.stack([speed, turn], axis=-1)
@property @property
def actions_speeds_turns(self): def actions_speeds_turns(self) -> np.ndarray:
if "calculated_actions_speeds_turns" in self: if "calculated_actions_speeds_turns" in self:
assert ( assert (
self.attrs["poses_hash"] == self.poses_hash self.attrs["poses_hash"] == self.poses_hash
...@@ -351,7 +349,7 @@ class Entity(h5py.Group): ...@@ -351,7 +349,7 @@ class Entity(h5py.Group):
return self.calculate_actions_speeds_turns() return self.calculate_actions_speeds_turns()
@property @property
def orientations_rad(self): def orientations_rad(self) -> np.ndarray:
if "calculated_orientations_rad" in self: if "calculated_orientations_rad" in self:
assert ( assert (
self.attrs["poses_hash"] == self.poses_hash self.attrs["poses_hash"] == self.poses_hash
......
...@@ -63,20 +63,20 @@ class File(h5py.File): ...@@ -63,20 +63,20 @@ class File(h5py.File):
def __init__( def __init__(
self, self,
path: Union[str, Path] = None, path: Optinal[Union[str, Path]] = None,
mode: str = "r", mode: str = "r",
*, # PEP 3102 *, # PEP 3102
world_size_cm: List[int] = None, world_size_cm: Optional[List[int]] = None,
world_shape: str = "rectangle", world_shape: str = "rectangle",
validate: bool = False, validate: bool = False,
validate_when_saving: bool = True, validate_when_saving: bool = True,
strict_validate: bool = False, strict_validate: bool = False,
format_version: List[int] = default_format_version, format_version: List[int] = default_format_version,
format_url: str = default_format_url, format_url: str = default_format_url,
sampling_name: str = None, sampling_name: Optional[str] = None,
frequency_hz: int = None, frequency_hz: Optional[int] = None,
monotonic_time_points_us: Iterable = None, monotonic_time_points_us: Optional[Iterable] = None,
calendar_time_points: Iterable = None, calendar_time_points: Optional[Iterable] = None,
open_copy: bool = False, open_copy: bool = False,
validate_poses_hash: bool = True, validate_poses_hash: bool = True,
calculate_data_on_close: bool = True, calculate_data_on_close: bool = True,
...@@ -332,11 +332,11 @@ class File(h5py.File): ...@@ -332,11 +332,11 @@ class File(h5py.File):
def create_sampling( def create_sampling(
self, self,
name: str = None, name: Optional[str] = None,
frequency_hz: int = None, frequency_hz: Optional[int] = None,
monotonic_time_points_us: Iterable = None, monotonic_time_points_us: Optional[Iterable] = None,
calendar_time_points: Iterable = None, calendar_time_points: Optional[Iterable] = None,
default: bool = False, default: Optional[bool] = False,
): ):
# Find Name for sampling if none is given # Find Name for sampling if none is given
if name is None: if name is None:
...@@ -465,24 +465,24 @@ class File(h5py.File): ...@@ -465,24 +465,24 @@ class File(h5py.File):
def create_entity( def create_entity(
self, self,
category: str, category: str,
poses: Iterable = None, poses: Optional[Iterable] = None,
name: str = None, name: Optional[str] = None,
individual_id: int = None, individual_id: Optional[int] = None,
positions: Iterable = None, positions: Optional[Iterable] = None,
orientations: Iterable = None, orientations: Optional[Iterable] = None,
outlines: Iterable = None, outlines: Optional[Iterable] = None,
sampling: str = None, sampling: Optional[str] = None,
) -> str: ) -> Entity:
"""Creates a new single entity. """Creates a new single entity.
Args: Args:
TODO TODO
category: the of the entity. The canonical values are ['organism', 'robot', 'obstacle']. category: the of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y). poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y).
poses_rad: optional two dimensional containing the poses of the entity (x,y, orientation_rad).
name: optional name of the entity. If no name is given, the is used with an id (e.g. 'fish_1') name: optional name of the entity. If no name is given, the is used with an id (e.g. 'fish_1')
individual_id (int, optional): invididual id of the entity. individual_id (int, optional): invididual id of the entity.
outlines: optional three dimensional array, containing the outlines of the entity outlines: optional three dimensional array, containing the outlines of the entity
sampling: The string refference to the sampling. If none is given, the standard sampling from creating the file is used.
Returns: Returns:
Name of the created entity Name of the created entity
""" """
...@@ -510,11 +510,11 @@ class File(h5py.File): ...@@ -510,11 +510,11 @@ class File(h5py.File):
self, self,
category: str, category: str,
poses: Iterable, poses: Iterable,
names: Iterable[str] = None, names: Optional[Iterable[str]] = None,
individual_ids: Iterable[int] = None, individual_ids: Optional[Iterable[int]] = None,
outlines=None, outlines=None,
sampling=None, sampling=None,
) -> Iterable: ) -> Iterable[Entity]:
"""Creates multiple entities. """Creates multiple entities.
Args: Args:
...@@ -868,13 +868,13 @@ class File(h5py.File): ...@@ -868,13 +868,13 @@ class File(h5py.File):
def plot( def plot(
self, self,
ax: matplotlib.axes = None, ax: Optional[matplotlib.axes] = None,
lw_distances: bool = False, lw_distances: bool = False,
lw: int = 2, lw: int = 2,
ms: int = 32, ms: int = 32,
figsize: Tuple[int] = None, figsize: Optional[Tuple[int]] = None,
step_size: int = 25, step_size: int = 25,
c: List = None, c: Optional[List] = None,
cmap: matplotlib.colors.Colormap = "Set1", cmap: matplotlib.colors.Colormap = "Set1",
skip_timesteps=0, skip_timesteps=0,
max_timesteps=None, max_timesteps=None,
...@@ -985,7 +985,7 @@ class File(h5py.File): ...@@ -985,7 +985,7 @@ class File(h5py.File):
return ax return ax
def render(self, video_path: Union[str, Path] = None, **kwargs: Dict) -> None: def render(self, video_path: Optional[Union[str, Path]] = None, **kwargs: Dict) -> None:
"""Render a video of the file. """Render a video of the file.
The tracks are rendered in a video using matplotlib.animation.FuncAnimation. The tracks are rendered in a video using matplotlib.animation.FuncAnimation.
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment