diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 0a5e1e649a7bc1c5b7dbedd07049d4d3b061f87e..f0c7526fe8458582b01c56f7e7c130d0da0884bd 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -36,20 +36,21 @@ def function_dict() -> dict: "follow_iid": base.evaluate_follow_iid, "individual_speed": base.evaluate_individual_speed, "individual_iid": base.evaluate_individual_iid, - "quiver": base.evaluate_quiver, + # "quiver": base.evaluate_quiver, # Quiver has issues with multiple paths and raises exceptions. The function is not used for now. "all": base.evaluate_all, } -def evaluate(args=None): +def evaluate(args: dict = None) -> None: """This function can be called from the commandline to evaluate files. The function is called with robofish-io-evaluate. Different evaluation methods can be called, which generate graphs from the given files Args: - args: a dictionary to overwrite the argument parser - (robofish-io-evaluate --help for more info) + args (dict, optional): a dictionary to overwrite the argument parser (robofish-io-evaluate --help for more info) + Raises: + ValueError: When the analysis type is all, and no save_path is given. """ fdict = function_dict() @@ -104,7 +105,7 @@ def evaluate(args=None): args = parser.parse_args() if args.analysis_type == "all" and args.save_path is None: - raise Exception("When the analysis type is all, a --save_path must be given.") + raise ValueError("When the analysis type is all, a --save_path must be given.") if args.analysis_type in fdict: if args.labels is None: diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index a597237ad389878518f59d4d01cc7f084e24e900..ac351f7251fc81eba25ab8ed54847a261b9f3e61 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -515,13 +515,17 @@ def evaluate_quiver( # ) # print(speeds_turns_from_paths) try: - speeds_turns_from_paths = np.array(speeds_turns_from_paths) + print(speeds_turns_from_paths) + speeds_turns_from_paths = np.stack( + [np.array(st, dtype=np.float32) for st in speeds_turns_from_paths] + ) except Exception as e: warnings.warn( - f"The conversion to numpy array failed:\nspeeds_turns_from_path.shape was {speeds_turns_from_paths.shape}. Exception was {e}" + f"The conversion to numpy array failed:\nlen(speeds_turns_from_path) was {len(speeds_turns_from_paths)}. Exception was {e}" ) return - + print(speeds_turns_from_paths.dtype) + print(speeds_turns_from_paths) all_poses = torch.tensor(poses_from_paths)[0, :, :, :-1].reshape((-1, 4)) speed = torch.tensor(speeds_turns_from_paths)[0, ..., 0].flatten() all_poses_speed = torch.clone(all_poses) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 5ca9801b002dd9ab637c7222deb2807979b6d43a..cacbcb6a1e156b12d908a2faaf208e7059453cca 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -16,7 +16,6 @@ # ----------------------------------------------------------- from __future__ import annotations -from pytest import skip import robofish.io from robofish.io.entity import Entity import h5py @@ -36,11 +35,10 @@ import warnings from textwrap import wrap import platform -import matplotlib as mpl +import matplotlib import matplotlib.pyplot as plt from matplotlib import animation from matplotlib import patches -from matplotlib import cm from tqdm.auto import tqdm @@ -850,19 +848,19 @@ class File(h5py.File): def plot( self, - ax: mpl.axes = None, + ax: matplotlib.axes = None, lw_distances: bool = False, lw: int = 2, ms: int = 32, figsize: Tuple[int] = None, step_size: int = 4, c: List = None, - cmap: mpl.colors.Colormap = "Set1", + cmap: matplotlib.colors.Colormap = "Set1", skip_timesteps=0, max_timesteps=None, show=False, legend=True, - ) -> mpl.axes: + ) -> matplotlib.axes: """Plot the file using matplotlib.pyplot The tracks in the file are plotted using matplotlib.plot(). @@ -910,7 +908,7 @@ class File(h5py.File): else: step_size = poses.shape[1] - cmap = cm.get_cmap(cmap) + cmap = matplotlib.colormaps[cmap] x_world, y_world = self.world_size if figsize is None: @@ -1260,10 +1258,10 @@ class File(h5py.File): ) current_pose = entity_poses[i_entity, file_frame] - t = mpl.transforms.Affine2D().translate( + t = matplotlib.transforms.Affine2D().translate( current_pose[0], current_pose[1] ) - r = mpl.transforms.Affine2D().rotate(current_pose[2]) + r = matplotlib.transforms.Affine2D().rotate(current_pose[2]) tra = r + t + ax.transData entity_polygons[i_entity].set_transform(tra) else: