diff --git a/commands.md b/commands.md index f1647c156f117969da9fc5363c27d262c9eeb800..ec78dce2ec0cfb32d3007e1c1920e0b1a86172f0 100644 --- a/commands.md +++ b/commands.md @@ -4,3 +4,6 @@ pdoc --html robofish.io robofish.evaluate --html-dir docs --force ## Code coverage: pytest --cov=src --cov-report=html + +## Flake +flake8 --ignore E203 --max-line-length 88 diff --git a/pyproject.toml b/pyproject.toml index 57e73c268595219f619cc4da8846d349ab426f21..139acbc7f3d8ee316285ade2c7c1ef2d9efe4eed 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,7 @@ [build-system] # Minimum requirements for the build system to execute. -requires = ["setuptools", "wheel", "pytest", "pandas", "deprecation"] # PEP 508 specifications. +requires = ["setuptools", "wheel"] # PEP 508 specifications. + +[flake8] +max-line-length = 88 +extend-ignore = "E203" diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index db339f14f3ef5cbaea23da1cf3ccffa09809f804..6670645a03ce8a6460a6e4febbbc63b3c8e32e74 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -1,8 +1,6 @@ # -*- coding: utf-8 -*- -""" -Evaluation functions, to generate graphs from files. -""" +"""Evaluation functions, to generate graphs from files.""" # Feb 2021 Andreas Gerken, Marc Groeling Berlin, Germany # Released under GNU 3.0 License @@ -10,6 +8,8 @@ Evaluation functions, to generate graphs from files. # Last doku update Feb 2021 import robofish.io +from pathlib import Path + import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec import seaborn as sns @@ -20,29 +20,56 @@ from scipy import stats def get_all_poses_from_paths(paths: Iterable[str]): - """This function reads all poses from given paths. - - # Args: - # paths: An array of strings, with files or folders - # Returns: - # An array, containing poses with the shape [paths][files][entities, timesteps, 4] - #""" - # # Open all files, shape (paths, files) - # files_per_path = [robofish.io.read_multiple_files(p) for p in paths] - - # Read all poses from the files, shape (paths, files) - poses_per_path = [[f.entity_poses for f in files] for files in files_per_path] - - -# # close all files -# for p in files_per_path: -# for f in p: -# f.close() + """Read all poses from given paths. -# return poses_per_path + The function shall be used by the evaluation functions. + Args: + paths: An array of strings, with files or folders. + The files are checked to have the same frequency. + Returns: + An array, containing poses with the shape + [paths][files][entities, timesteps, 4], + the common frequency of the files + """ + # Open all files, shape (paths, files) + files_per_path = [robofish.io.read_multiple_files(p) for p in paths] -def evaluate_speed(paths, names=None, save_path=None, predicate=None): + # Read all poses from the files, shape (paths, files) + poses_per_path = [ + [f.entity_poses for path, f in files_dict.items()] + for files_dict in files_per_path + ] + + # close all files + frequencies = [] + for files_dict in files_per_path: + for path, f in files_dict.items(): + frequencies.append(f.frequency) + f.close() + + # Check that all frequencies are equal + assert np.std(frequencies) == 0 + + return poses_per_path, frequencies[0] + + +def evaluate_speed( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the speed of the entities as histogram. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] speeds = [] for k, files in enumerate(files_per_path): @@ -57,10 +84,10 @@ def evaluate_speed(paths, names=None, save_path=None, predicate=None): path_speeds.extend(e_speeds) speeds.append(path_speeds) - if names is None: - names = paths + if labels is None: + labels = paths - plt.hist(speeds, bins=20, label=names, density=True, range=[0, 50]) + plt.hist(speeds, bins=20, label=labels, density=True, range=[0, 50]) plt.title("Agent speeds") plt.xlabel("Speed [cm/s]") plt.ylabel("Frequency") @@ -75,7 +102,22 @@ def evaluate_speed(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_turn(paths, names=None, save_path=None, predicate=None): +def evaluate_turn( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the turn angles of the entities as histogram. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] turns = [] for k, files in enumerate(files_per_path): @@ -91,7 +133,8 @@ def evaluate_turn(paths, names=None, save_path=None, predicate=None): for e_poses in poses: # convert ori_x, ori_y to radians ori_rad = np.arctan2(e_poses[:, 3], e_poses[:, 2]) - # calculate difference, also make it take the "shorter" turn (so that is in range of -pi, pi) + # calculate difference, also make it take the "shorter" + # turn (so that is in range of -pi, pi) e_turns = ori_rad[1:] - ori_rad[:-1] e_turns = np.where(e_turns < -np.pi, e_turns + 2 * np.pi, e_turns) e_turns = np.where(e_turns > np.pi, e_turns - 2 * np.pi, e_turns) @@ -100,11 +143,11 @@ def evaluate_turn(paths, names=None, save_path=None, predicate=None): path_turns.extend(e_turns) turns.append(path_turns) - if names is None: - names = paths + if labels is None: + labels = paths # TODO: Quantil range - plt.hist(turns, bins=40, label=names, density=True, range=[-30, 30]) + plt.hist(turns, bins=40, label=labels, density=True, range=[-30, 30]) plt.title("Agent turns") plt.xlabel("Change in orientation [Degree / timestep at %dhz]" % frequency) plt.ylabel("Frequency") @@ -119,7 +162,22 @@ def evaluate_turn(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_orientation(paths, names=None, save_path=None, predicate=None): +def evaluate_orientation( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the orientations of the entities on a 2d grid. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] orientations = [] for k, files in enumerate(files_per_path): @@ -149,18 +207,18 @@ def evaluate_orientation(paths, names=None, save_path=None, predicate=None): if len(orientations) == 1: ax = [ax] - if names is None: - names = paths + if labels is None: + labels = paths for i in range(len(orientations)): orientation = orientations[i] s_1, x_edges, y_edges, bnr = orientation[0] s_2, x_edges, y_edges, bnr = orientation[1] - if names is None: + if labels is None: ax[i].set_title("Mean orientation in tank") else: - ax[i].set_title("Mean orientation in tank (%s)" % names[i]) + ax[i].set_title("Mean orientation in tank (%s)" % labels[i]) ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") @@ -174,7 +232,7 @@ def evaluate_orientation(paths, names=None, save_path=None, predicate=None): vmax=np.pi, cmap="twilight", ) - cbar = plt.colorbar(plot, ax=ax[i], pad=0.015, aspect=10) + # cbar = plt.colorbar(plot, ax=ax[i], pad=0.015, aspect=10) show_values(plot) if save_path is None: @@ -184,7 +242,22 @@ def evaluate_orientation(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_relativeOrientation(paths, names=None, save_path=None, predicate=None): +def evaluate_relativeOrientation( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the relative orientations of the entities as a histogram. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] orientations = [] for k, files in enumerate(files_per_path): @@ -204,10 +277,10 @@ def evaluate_relativeOrientation(paths, names=None, save_path=None, predicate=No path_orientations.extend(np.arctan2(ori_diff[1], ori_diff[0])) orientations.append(path_orientations) - if names is None: - names = paths + if labels is None: + labels = paths - plt.hist(orientations, bins=40, label=names, density=True, range=[0, np.pi]) + plt.hist(orientations, bins=40, label=labels, density=True, range=[0, np.pi]) plt.title("Relative orientation") plt.xlabel("orientation in radians") plt.ylabel("Frequency") @@ -222,9 +295,21 @@ def evaluate_relativeOrientation(paths, names=None, save_path=None, predicate=No plt.close() -def evaluate_distanceToWall(paths, names=None, save_path=None, predicate=None): - """ - only works for rectangular tanks +def evaluate_distanceToWall( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the distances of the entities to the walls as a histogram. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] distances = [] @@ -265,13 +350,13 @@ def evaluate_distanceToWall(paths, names=None, save_path=None, predicate=None): worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - if names is None: - names = paths + if labels is None: + labels = paths plt.hist( distances, bins=20, - label=names, + label=labels, density=True, range=[0, min(worldBoundsX / 2, worldBoundsY / 2)], ) @@ -289,10 +374,23 @@ def evaluate_distanceToWall(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None): - """ - Heatmap of fishpositions - By Moritz Maxeiner +def evaluate_tankpositions( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the positions of the entities as a heatmap. + + The original implementation is by Moritz Maxeiner. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] x_pos, y_pos = [], [] @@ -313,11 +411,11 @@ def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None): fig, ax = plt.subplots(1, len(x_pos), figsize=(8 * len(x_pos), 8)) if len(x_pos) == 1: ax = [ax] - if names is None: - names = paths + if labels is None: + labels = paths for i in range(len(x_pos)): - ax[i].set_title("Tankpositions (%s)" % names[i]) + ax[i].set_title("Tankpositions (%s)" % labels[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) @@ -330,10 +428,23 @@ def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_trajectories(paths, names=None, save_path=None, predicate=None): - """ - trajectories of fishes - By Moritz Maxeiner +def evaluate_trajectories( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the trajectories of the entities. + + The original implementation is by Moritz Maxeiner. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] pos = [] @@ -359,15 +470,15 @@ def evaluate_trajectories(paths, names=None, save_path=None, predicate=None): fig, ax = plt.subplots(1, len(pos), figsize=(len(pos) * 8, 8)) if len(pos) == 1: ax = [ax] - if names is None: - names = paths + if labels is None: + labels = paths for i in range(len(pos)): sns.set_style("white", {"axes.linewidth": 2, "axes.edgecolor": "black"}) sns.scatterplot( x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i] ) - ax[i].set_title("Trajectories (%s)" % names[i]) + ax[i].set_title("Trajectories (%s)" % labels[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) ax[i].invert_yaxis() @@ -400,7 +511,22 @@ def evaluate_trajectories(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_positionVec(paths, names=None, save_path=None, predicate=None): +def evaluate_positionVec( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] posVec = [] worldBoundsX, worldBoundsY = [], [] @@ -442,7 +568,7 @@ def evaluate_positionVec(paths, names=None, save_path=None, predicate=None): fig = plt.figure(figsize=(8 * len(grids), 8)) - fig.suptitle("positionVector: from left to right: " + str(names), fontsize=16) + fig.suptitle("positionVector: from left to right: " + str(labels), fontsize=16) gs = gridspec.GridSpec(1, len(grids)) for i in range(len(grids)): @@ -455,7 +581,24 @@ def evaluate_positionVec(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_follow_iid(paths, names=None, save_path=None, predicate=None): +def evaluate_follow_iid( + paths: Iterable[str], + labels: Iterable[str] = None, + save_path: str = None, + predicate=None, +): + """Evaluate the follow metric in respect to the inter individual distance (iid). + + The original implementation is from Moritz Maxeiner. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ files_per_path = [robofish.io.read_multiple_files(p) for p in paths] follow, iid = [], [] worldBoundsX, worldBoundsY = [], [] @@ -510,7 +653,7 @@ def evaluate_follow_iid(paths, names=None, save_path=None, predicate=None): fig = plt.figure(figsize=(8 * len(grids), 8)) - fig.suptitle("follow/iid: from left to right: " + str(names), fontsize=16) + fig.suptitle("follow/iid: from left to right: " + str(labels), fontsize=16) gs = gridspec.GridSpec(1, len(grids)) for i in range(len(grids)): @@ -523,23 +666,43 @@ def evaluate_follow_iid(paths, names=None, save_path=None, predicate=None): plt.close() -def evaluate_all(paths, names=None, save_folder=None, predicate=None): - evaluate_speed(paths, names, save_folder + "speed.png", predicate) - evaluate_turn(paths, names, save_folder + "turn.png", predicate) - evaluate_orientation(paths, names, save_folder + "orientation.png", predicate) +def evaluate_all( + paths: Iterable[str], + labels: Iterable[str] = None, + save_folder: str = None, + predicate=None, +): + """Generate all evaluation graphs and save them to a folder. + + Args: + paths: An array of strings, with files of folders. + labels: Labels for the paths. If no labels are given, the paths will + be used + save_path: A path to a save location. + predicate: a lambda function, selecting entities + (example: lambda e: e.category == "fish") + """ + save_folder = Path(save_folder) + evaluate_speed(paths, labels, save_folder + "speed.png", predicate) + evaluate_turn(paths, labels, save_folder + "turn.png", predicate) + evaluate_orientation(paths, labels, save_folder + "orientation.png", predicate) evaluate_relativeOrientation( - paths, names, save_folder + "relativeOrientation.png", predicate + paths, labels, save_folder + "relativeOrientation.png", predicate + ) + evaluate_distanceToWall( + paths, labels, save_folder + "distanceToWall.png", predicate ) - evaluate_distanceToWall(paths, names, save_folder + "distanceToWall.png", predicate) - evaluate_tankpositions(paths, names, save_folder + "tankpositions.png", predicate) - evaluate_trajectories(paths, names, save_folder + "trajectories.png", predicate) - evaluate_positionVec(paths, names, save_folder + "posVec.png", predicate) - evaluate_follow_iid(paths, names, save_folder + "follow_iid.png", predicate) + evaluate_tankpositions(paths, labels, save_folder + "tankpositions.png", predicate) + evaluate_trajectories(paths, labels, save_folder + "trajectories.png", predicate) + evaluate_positionVec(paths, labels, save_folder + "posVec.png", predicate) + evaluate_follow_iid(paths, labels, save_folder + "follow_iid.png", predicate) def calculate_follow(a, b): - """ - Given two series of poses - with X and Y coordinates of their positions as the first two elements - + """Calculate the follow metric. + + Given two series of poses - with X and Y coordinates of their + positions as the first two elements return the follow metric from the first to the second series. """ a_v = a[1:, :2] - a[:-1, :2] @@ -548,15 +711,18 @@ def calculate_follow(a, b): def calculate_iid(a, b): - """ - Given two series of poses - with X and Y coordinates of their positions as the first two elements - + """Calculate the iid metric. + + Given two series of poses - with X and Y coordinates of their positions + as the first two elements return the inter-individual distance (between the positions). """ return np.linalg.norm(b[:, :2] - a[:, :2], axis=-1) def normalize_series(x): - """ + """Normalize series. + Given a series of vectors, return a series of normalized vectors. Null vectors are mapped to `NaN` vectors. """ @@ -564,9 +730,11 @@ def normalize_series(x): def calculate_posVec(data, angle): - """ - data should be of the form (n, (x1, y1, x2, y2)) and angle of the form (n, 1) - returns x, y distance from fish1 to fish2 with respect to the direction fish1 is facing (shape: (n, (x, y))) + """Calculate position vectors. + + Data should be of the form (n, (x1, y1, x2, y2)) and angle of the form (n, 1) + returns x, y distance from fish1 to fish2 with respect to the direction + fish1 is facing (shape: (n, (x, y))) """ # rotate axes by angle and adjust x,y of points temp = np.copy(data) @@ -579,10 +747,13 @@ def calculate_posVec(data, angle): def calculate_distLinePoint(x1, y1, x2, y2, points): - """ - computes distance between line (x1, y1, x2, y2) and points (np array of shape (n, 2), (x,y) on each row) + """Compute the distance between a line and points. + + Compute the distance between line (x1, y1, x2, y2) and points + (np array of shape (n, 2), (x,y) on each row) returns np array of shape (n, ) with corresponding distances - from: https://stackoverflow.com/questions/849211/shortest-distance-between-a-point-and-a-line-segment + from: https://stackoverflow.com/questions/849211/ + shortest-distance-between-a-point-and-a-line-segment theory: http://paulbourke.net/geometry/pointlineplane/ """ px = x2 - x1 @@ -607,9 +778,10 @@ def calculate_distLinePoint(x1, y1, x2, y2, points): def show_values(pc, fmt="%.2f", **kw): - """ - shows numbers on plt.ax.pccolormesh plot - https://stackoverflow.com/questions/25071968/heatmap-with-text-in-each-cell-with-matplotlibs-pyplot + """Show numbers on plt.ax.pccolormesh plot. + + https://stackoverflow.com/questions/25071968/ + heatmap-with-text-in-each-cell-with-matplotlibs-pyplot """ pc.update_scalarmappable() ax = pc.axes @@ -624,11 +796,14 @@ def show_values(pc, fmt="%.2f", **kw): class SeabornFig2Grid: - """ - copied from https://stackoverflow.com/questions/35042255/how-to-plot-multiple-seaborn-jointplot-in-subplot + """Seaborn Figure class. + + copied from https://stackoverflow.com/questions/35042255/ + how-to-plot-multiple-seaborn-jointplot-in-subplot """ def __init__(self, seaborngrid, fig, subplot_spec): + """Init function.""" self.fig = fig self.sg = seaborngrid self.subplot = subplot_spec @@ -641,7 +816,7 @@ class SeabornFig2Grid: self._finalize() def _movegrid(self): - """ Move PairGrid or Facetgrid """ + """Move PairGrid or Facetgrid.""" self._resize() n = self.sg.axes.shape[0] m = self.sg.axes.shape[1] @@ -651,7 +826,7 @@ class SeabornFig2Grid: self._moveaxes(self.sg.axes[i, j], self.subgrid[i, j]) def _movejointgrid(self): - """ Move Jointgrid """ + """Move Jointgrid.""" h = self.sg.ax_joint.get_position().height h2 = self.sg.ax_marg_x.get_position().height r = int(np.round(h / h2)) @@ -665,7 +840,11 @@ class SeabornFig2Grid: self._moveaxes(self.sg.ax_marg_y, self.subgrid[1:, -1]) def _moveaxes(self, ax, gs): - # https://stackoverflow.com/a/46906599/4124317 + """ + Move axes. + + https://stackoverflow.com/a/46906599/4124317 + """ ax.remove() ax.figure = self.fig self.fig.axes.append(ax) @@ -675,9 +854,11 @@ class SeabornFig2Grid: ax.set_subplotspec(gs) def _finalize(self): + """Finalize the graphics.""" plt.close(self.sg.fig) self.fig.canvas.mpl_connect("resize_event", self._resize) self.fig.canvas.draw() def _resize(self, evt=None): + """Resize the graphics.""" self.sg.fig.set_size_inches(self.fig.get_size_inches()) diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 0e1ae54889dcbe9a674e19187ff6128d33889c84..8ea10175c7fed8d0e4a813856290116d86d83b54 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -334,13 +334,18 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): ori_lengths = np.linalg.norm(orientations, axis=1) + # import matplotlib.pyplot as plt + # + # plt.plot(ori_lengths) + # plt.show() + # Check if all orientation lengths are all 1. Different lengths cause warnings. assert_validate( np.isclose(ori_lengths, 1).all(), - "The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1" + "The orientation vectors were not unit vectors. Their length was in the range [%.4f, %.4f] when it should be 1" % (min(ori_lengths), max(ori_lengths)), e_name, - strict_validate=False, + strict_validate=True, ) diff --git a/tests/robofish/evaluate/test_evaluate.py b/tests/robofish/evaluate/test_evaluate.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..65aa7271cff6a82026cf38620a1fd68220656446 100644 --- a/tests/robofish/evaluate/test_evaluate.py +++ b/tests/robofish/evaluate/test_evaluate.py @@ -0,0 +1,13 @@ +import pytest +import robofish.evaluate +from robofish.io import utils +import numpy as np + + +def test_get_all_poses_from_paths(): + valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5") + poses, frequency = robofish.evaluate.get_all_poses_from_paths([valid_file_path]) + + # (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses) + assert np.array(poses).shape == (1, 1, 2, 100, 4) + assert frequency == 25 diff --git a/tests/robofish/io/test_examples.py b/tests/robofish/io/test_examples.py index c81b1eed99221e9d92a82f0d354352fa869cfbae..82ea16dd4153f2449a52be691d457d3bf25df827 100644 --- a/tests/robofish/io/test_examples.py +++ b/tests/robofish/io/test_examples.py @@ -1,7 +1,7 @@ import robofish.io from robofish.io import utils from pathlib import Path -from testbook import testbook + import sys @@ -29,6 +29,8 @@ def test_example_basic(): # This test can be executed manually. The CI/CD System has issues with testbook. def manual_test_example_basic_ipynb(): + from testbook import testbook + # Executing the notebook should not lead to an exception with testbook(str(ipynb_path), execute=True) as tb: pass