diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py
index a73fe46fcd99d3791b02bdb56e9d02ceffaf5076..00e5c5488caa9ee77bdc3694fccf93df79fbe0f3 100644
--- a/src/robofish/evaluate/app.py
+++ b/src/robofish/evaluate/app.py
@@ -22,13 +22,13 @@ def function_dict():
         "turn": base.evaluate_turn,
         "orientation": base.evaluate_orientation,
         "relative_orientation": base.evaluate_relative_orientation,
-        "distance_to_wall": base.evaluate_distanceToWall,
-        "tank_positions": base.evaluate_tankpositions,
+        "distance_to_wall": base.evaluate_distance_to_wall,
+        "tank_position": base.evaluate_tank_position,
         "tracks": base.evaluate_tracks,
         "tracks_distance": base.evaluate_tracks_distance,
-        "socialVec": base.evaluate_socialVec,
+        "social_vector": base.evaluate_social_vector,
        "follow_iid": base.evaluate_follow_iid,
-        "individual_speeds": base.evaluate_individual_speeds,
+        "individual_speed": base.evaluate_individual_speed,
         "individual_iid": base.evaluate_individual_iid,
         "all": base.evaluate_all,
     }
diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py
index 83b449dfdb0f42287864acc8014cbe77943e1732..c5f2982dae2a863a2c289ff2378b0f644c41a894 100644
--- a/src/robofish/evaluate/evaluate.py
+++ b/src/robofish/evaluate/evaluate.py
@@ -2,13 +2,16 @@
 
 """Evaluation functions, to generate graphs from files."""
 
-# Feb 2021 Andreas Gerken, Marc Groeling Berlin, Germany
+# Feb 2022 Andreas Gerken Berlin, Germany
 # Released under GNU 3.0 License
 # email andi.gerken@gmail.com
 
-# Last doku update Feb 2021
 import robofish.io
 import robofish.evaluate
+from robofish.io import utils
 
 from pathlib import Path
 import matplotlib.pyplot as plt
@@ -16,13 +19,28 @@ import matplotlib.gridspec as gridspec
 import seaborn as sns
 import numpy as np
 import pandas as pd
-from typing import Iterable
+from typing import Iterable, Union
 from scipy import stats
 from tqdm import tqdm
+import inspect
 import random
 
 
-def get_all_poses_from_paths(paths: Iterable[str]):
+def get_all_files_from_paths(paths: Iterable[Union[str, Path]]):
+    # Find all files with correct ending
+    files = []
+    for path in [Path(p) for p in paths]:
+        if path.is_dir():
+            files_path = []
+            for ext in ("hdf", "hdf5", "h5", "he5"):
+                files_path += list(path.rglob(f"*.{ext}"))
+            files.append(files_path)
+        else:
+            files.append([path])
+    return files
+
+
+def get_all_poses_from_paths(paths: Iterable[Union[str, Path]], predicate=None):
    """Read all poses from given paths.
     The function shall be used by the evaluation functions.
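For orientation, a quick usage sketch of the two loader helpers introduced above. This is illustrative only: the "data/" folder is a made-up example, and importing get_all_files_from_paths from robofish.evaluate.evaluate is an assumption; get_all_poses_from_paths is re-exported as robofish.evaluate.get_all_poses_from_paths, as the updated test at the end of this patch shows.

import robofish.evaluate
from robofish.evaluate.evaluate import get_all_files_from_paths

# File discovery: directories are searched recursively for hdf/hdf5/h5/he5 files,
# plain file paths are passed through untouched.
files_per_path = get_all_files_from_paths(["data/", "data/single_track.hdf5"])

# Pose loading: one array per file, grouped per input path, plus the shared settings.
poses_per_path, file_settings = robofish.evaluate.get_all_poses_from_paths(["data/"])
# poses_per_path[0][0] has shape (entities, timesteps, 4)
# file_settings holds "world_size_cm_x", "world_size_cm_y" and "frequency_hz"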
@@ -35,32 +53,56 @@ def get_all_poses_from_paths(paths: Iterable[str]):
         [paths][files][entities, timesteps, 4], the common frequency of the files
     """
 
-    # Open all files, shape (paths, files)
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+    return get_all_data_from_paths(paths, "poses_4d", predicate)
 
-    # Read all poses from the files, shape (paths, files)
-    poses_per_path = [
-        [f.entity_poses for path, f in files_dict.items()]
-        for files_dict in files_per_path
-    ]
 
-    # close all files and collect the frequencies
-    frequencies = []
-    for files_dict in files_per_path:
-        for path, f in files_dict.items():
-            frequencies.append(f.frequency)
-            f.close()
+def get_all_data_from_paths(
+    paths: Iterable[Union[str, Path]], request_type="poses_4d", predicate=None
+):
+    expected_settings = None
+    all_data = []
+
+    files_per_path = get_all_files_from_paths(paths)
+
+    # for each given path
+    for i_path, files_in_path in enumerate(files_per_path):
+        data_from_files = []
+
+        # Open all files, gather the poses and check that all of them
+        # have the same world size and frequency
+        for file_path in files_in_path:
+            with robofish.io.File(file_path, "r") as file:
+                file_settings = {
+                    "world_size_cm_x": file.attrs["world_size_cm"][0],
+                    "world_size_cm_y": file.attrs["world_size_cm"][1],
+                    "frequency_hz": file.frequency,
+                }
+                if expected_settings is None:
+                    expected_settings = file_settings
+
+                assert file_settings == expected_settings
+
+                properties = {
+                    "poses_4d": robofish.io.Entity.poses,
+                    "speeds_turns": robofish.io.Entity.speed_turn,
+                }
+
+                # The predicate is given per input path, not per file.
+                pred = None if predicate is None else predicate[i_path]
+                data = file.select_entity_property(
+                    pred, entity_property=properties[request_type]
+                )
+                data_from_files.append(data)
 
-    # Check that all frequencies are equal
-    assert np.std(frequencies) == 0
+        all_data.append(data_from_files)
 
-    return poses_per_path, frequencies[0]
+    return all_data, expected_settings
 
 
 def evaluate_speed(
-    paths: Iterable[str],
+    paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
     predicate=None,
+    speeds_turns_from_paths=None,
+    file_settings=None,
 ):
     """Evaluate the speed of the entities as histogram.
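A sketch of the new return contract of get_all_data_from_paths. Illustrative only: the folder names are invented, and the package-level access is an assumption (if it is not re-exported, import it from robofish.evaluate.evaluate).

import numpy as np
import robofish.evaluate

paths = ["model_tracks/", "real_tracks/"]
speeds_turns, settings = robofish.evaluate.get_all_data_from_paths(paths, "speeds_turns")

# speeds_turns[i][j] holds the actions of the j-th file found under paths[i],
# shape (entities, timesteps, 2): column 0 is the speed per timestep,
# column 1 is the turn in radians.
first_file = speeds_turns[0][0]
speed_cm_per_s = first_file[..., 0] * settings["frequency_hz"]
turn_deg = np.rad2deg(first_file[..., 1])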
@@ -71,21 +113,23 @@ def evaluate_speed(
         predicate: a lambda function, selecting entities
             (example: lambda e: e.category == "fish")
     """
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+
+    if speeds_turns_from_paths is None:
+        speeds_turns_from_paths, file_settings = get_all_data_from_paths(
+            paths, "speeds_turns"
+        )
+
     speeds = []
-    left_quantiles = []
-    right_quantiles = []
-    frequency = None
-    for k, files in enumerate(files_per_path):
+    # Iterate all paths
+    for speeds_turns_per_path in speeds_turns_from_paths:
         path_speeds = []
-        for p, file in files.items():
-            assert frequency is None or frequency == file.frequency
-            frequency = file.frequency
+        left_quantiles, right_quantiles = [], []
 
-            for e_speeds_turns in file.entity_actions_speeds_turns:
-                path_speeds.extend(e_speeds_turns[:, 0] * frequency)
-            file.close()
+        # Iterate all files
+        for speeds_turns in speeds_turns_per_path:
+            for e_speeds_turns in speeds_turns:
+                path_speeds.extend(
+                    e_speeds_turns[:, 0] * file_settings["frequency_hz"]
+                )
 
         # Exclude possible nans
         path_speeds = np.array(path_speeds)
@@ -98,7 +142,7 @@
     fig = plt.figure()
     plt.hist(
         list(speeds),
-        bins=20,
+        bins=30,
         label=labels,
         density=True,
         range=[min(left_quantiles), max(right_quantiles)],
@@ -113,9 +157,11 @@
 
 
 def evaluate_turn(
-    paths: Iterable[str],
+    paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
     predicate=None,
+    speeds_turns_from_paths=None,
+    file_settings=None,
 ):
     """Evaluate the turn angles of the entities as histogram.
 
@@ -126,21 +172,23 @@
         predicate: a lambda function, selecting entities
             (example: lambda e: e.category == "fish")
     """
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
-    turns = []
-    frequency = None
-    for k, files in enumerate(files_per_path):
+    if speeds_turns_from_paths is None:
+        speeds_turns_from_paths, file_settings = get_all_data_from_paths(
+            paths, "speeds_turns"
+        )
+
+    turns = []
+    # Iterate all paths
+    for speeds_turns_per_path in speeds_turns_from_paths:
         path_turns = []
         left_quantiles, right_quantiles = [], []
-        for p, file in files.items():
-            assert frequency is None or frequency == file.frequency
-            frequency = file.frequency
-            for e_speeds_turns in file.entity_actions_speeds_turns:
+        # Iterate all files
+        for speeds_turns in speeds_turns_per_path:
+            for e_speeds_turns in speeds_turns:
                 path_turns.extend(np.rad2deg(e_speeds_turns[:, 1]))
-            file.close()
 
         # Exclude possible nans
         path_turns = np.array(path_turns)
@@ -159,7 +207,9 @@
         range=[min(left_quantiles), max(right_quantiles)],
     )
     plt.title("Agent turns")
-    plt.xlabel("Change in orientation [Degree / timestep at %dhz]" % frequency)
+    plt.xlabel(
+        f"Change in orientation [Degree / timestep at {file_settings['frequency_hz']} hz]"
+    )
     plt.ylabel("Frequency")
     plt.ticklabel_format(useOffset=False, style="plain")
     plt.legend()
@@ -169,9 +219,11 @@
 
 
 def evaluate_orientation(
-    paths: Iterable[str],
+    paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
     predicate=None,
+    poses_from_paths=None,
+    file_settings=None,
 ):
     """Evaluate the orientations of the entities on a 2d grid.
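Because every evaluation function now accepts the preloaded arrays, a caller (such as evaluate_all later in this patch) can read the files once and fan the data out. A sketch of that pattern; the paths and labels are invented, and it assumes the evaluate_* functions are reachable on robofish.evaluate, as in app.py's function_dict.

import robofish.evaluate as ev

paths = ["generated_tracks/", "recorded_tracks/"]
labels = ["model", "fish"]

# Load once ...
speeds_turns, settings = ev.get_all_data_from_paths(paths, "speeds_turns")

# ... and reuse for several plots instead of reopening every file.
fig_speed = ev.evaluate_speed(
    paths, labels, speeds_turns_from_paths=speeds_turns, file_settings=settings
)
fig_turn = ev.evaluate_turn(
    paths, labels, speeds_turns_from_paths=speeds_turns, file_settings=settings
)
fig_speed.savefig("speed.png")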
@@ -182,30 +234,41 @@ def evaluate_orientation( predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ - files_per_path = [robofish.io.read_multiple_files(p) for p in paths] + if poses_from_paths is None: + poses_from_paths, file_settings = get_all_poses_from_paths(paths) + + world_bounds = [ + -file_settings["world_size_cm_x"] / 2, + -file_settings["world_size_cm_y"] / 2, + file_settings["world_size_cm_x"] / 2, + file_settings["world_size_cm_y"] / 2, + ] + orientations = [] - for k, files in enumerate(files_per_path): - for p, file in files.items(): - poses = file.select_entity_poses( - None if predicate is None else predicate[k] - ) - poses = poses.reshape((len(poses) * len(poses[0]), 4)) - world_size = file.attrs["world_size_cm"] - world_bounds = [ - -world_size[0] / 2, - -world_size[1] / 2, - world_size[0] / 2, - world_size[1] / 2, - ] + # Iterate all paths + for poses_per_path in poses_from_paths: + + # Iterate all files + for poses in poses_per_path: + reshaped_poses = poses.reshape((-1, 4)) + xbins = np.linspace(world_bounds[0], world_bounds[2], 11) ybins = np.linspace(world_bounds[1], world_bounds[3], 11) ret_1 = stats.binned_statistic_2d( - poses[:, 0], poses[:, 1], poses[:, 2], "mean", bins=[xbins, ybins] + reshaped_poses[:, 0], + reshaped_poses[:, 1], + reshaped_poses[:, 2], + "mean", + bins=[xbins, ybins], ) ret_2 = stats.binned_statistic_2d( - poses[:, 0], poses[:, 1], poses[:, 3], "mean", bins=[xbins, ybins] + reshaped_poses[:, 0], + reshaped_poses[:, 1], + reshaped_poses[:, 3], + "mean", + bins=[xbins, ybins], ) - file.close() + orientations.append((ret_1, ret_2)) fig, ax = plt.subplots(1, len(orientations), figsize=(8 * len(orientations), 8)) @@ -238,9 +301,11 @@ def evaluate_orientation( def evaluate_relative_orientation( - paths: Iterable[str], + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, + poses_from_paths=None, + file_settings=None, ): """Evaluate the relative orientations of the entities as a histogram. 
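The orientation grid above stores the mean x and y orientation components per spatial bin; the plotting code can then recover an average heading per cell with arctan2. A self-contained sketch of that reduction on synthetic data (not from the repository; the 100 cm tank size is an assumption):

import numpy as np
from scipy import stats

rng = np.random.default_rng(0)
xy = rng.uniform(-50, 50, size=(1000, 2))             # positions in an assumed 100 cm tank
headings = rng.uniform(-np.pi, np.pi, size=1000)
# Same column layout as poses_4d: x, y, orientation_x, orientation_y
poses = np.column_stack([xy, np.cos(headings), np.sin(headings)])

bins = np.linspace(-50, 50, 11)
mean_ox = stats.binned_statistic_2d(
    poses[:, 0], poses[:, 1], poses[:, 2], "mean", bins=[bins, bins]
).statistic
mean_oy = stats.binned_statistic_2d(
    poses[:, 0], poses[:, 1], poses[:, 3], "mean", bins=[bins, bins]
).statistic
mean_heading = np.arctan2(mean_oy, mean_ox)           # average heading per grid cell, in radians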
@@ -252,28 +317,30 @@ def evaluate_relative_orientation(
         predicate: a lambda function, selecting entities
             (example: lambda e: e.category == "fish")
     """
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+
+    if poses_from_paths is None:
+        poses_from_paths, file_settings = get_all_poses_from_paths(paths)
+
     orientations = []
-    for k, files in enumerate(files_per_path):
+    # Iterate all paths
+    for poses_per_path in poses_from_paths:
         path_orientations = []
-        for p, file in files.items():
-            poses = file.select_entity_poses(
-                None if predicate is None else predicate[k]
-            )
+        # Iterate all files
+        for poses in poses_per_path:
             for i in range(len(poses)):
+                angle1 = np.arctan2(poses[i, :, 3], poses[i, :, 2])
                 for j in range(len(poses)):
                     if i != j:
-                        ori_diff = (
-                            poses[j, :, 2] - poses[i, :, 2],
-                            poses[j, :, 3] - poses[i, :, 3],
+                        angle2 = np.arctan2(poses[j, :, 3], poses[j, :, 2])
+                        path_orientations.extend(
+                            utils.limit_angle_range(angle1 - angle2)
                         )
-                        path_orientations.extend(np.arctan2(ori_diff[1], ori_diff[0]))
-            file.close()
+
         orientations.append(path_orientations)
 
     fig = plt.figure()
-    plt.hist(orientations, bins=40, label=labels, density=True, range=[0, np.pi])
+    plt.hist(orientations, bins=40, label=labels, density=True, range=[-np.pi, np.pi])
     plt.title("Relative orientation")
     plt.xlabel("orientation in radians")
     plt.ylabel("Frequency")
@@ -284,10 +351,12 @@
     return fig
 
 
-def evaluate_distanceToWall(
-    paths: Iterable[str],
+def evaluate_distance_to_wall(
+    paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
     predicate=None,
+    poses_from_paths=None,
+    file_settings=None,
 ):
     """Evaluate the distances of the entities to the walls as a histogram.
 
@@ -298,30 +367,32 @@
         predicate: a lambda function, selecting entities
             (example: lambda e: e.category == "fish")
     """
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+
+    if poses_from_paths is None:
+        poses_from_paths, file_settings = get_all_poses_from_paths(paths)
+
+    world_bounds = [
+        -file_settings["world_size_cm_x"] / 2,
+        -file_settings["world_size_cm_y"] / 2,
+        file_settings["world_size_cm_x"] / 2,
+        file_settings["world_size_cm_y"] / 2,
+    ]
+    wall_lines = [
+        (world_bounds[0], world_bounds[1], world_bounds[0], world_bounds[3]),
+        (world_bounds[0], world_bounds[1], world_bounds[2], world_bounds[1]),
+        (world_bounds[2], world_bounds[3], world_bounds[2], world_bounds[1]),
+        (world_bounds[2], world_bounds[3], world_bounds[0], world_bounds[3]),
+    ]
+
     distances = []
-    worldBoundsX, worldBoundsY = [], []
-    for k, files in enumerate(files_per_path):
+
+    # Iterate all paths
+    for poses_per_path in poses_from_paths:
         path_distances = []
-        for p, file in files.items():
-            worldBoundsX.append(file.attrs["world_size_cm"][0])
-            worldBoundsY.append(file.attrs["world_size_cm"][1])
-            poses = file.select_entity_poses(
-                None if predicate is None else predicate[k]
-            )
-            world_size = file.attrs["world_size_cm"]
-            world_bounds = [
-                -world_size[0] / 2,
-                -world_size[1] / 2,
-                world_size[0] / 2,
-                world_size[1] / 2,
-            ]
-            wall_lines = [
-                (world_bounds[0], world_bounds[1], world_bounds[0], world_bounds[3]),
-                (world_bounds[0], world_bounds[1], world_bounds[2], world_bounds[1]),
-                (world_bounds[2], world_bounds[3], world_bounds[2], world_bounds[1]),
-                (world_bounds[2], world_bounds[3], world_bounds[0], world_bounds[3]),
-            ]
+
+        # Iterate all files
+        for poses in poses_per_path:
+            for e_poses in poses:
                 dist = []
 
                 for wall in wall_lines:
@@ -333,18
+404,24 @@ def evaluate_distanceToWall( # use distance only from closest wall dist = np.stack(dist).min(axis=0) path_distances.extend(dist) - file.close() - distances.append(path_distances) - worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) + distances.append(path_distances) fig = plt.figure() + plt.hist( distances, - bins=20, + bins=30, label=labels, density=True, - range=[0, min(worldBoundsX / 2, worldBoundsY / 2)], + range=[ + 0, + # Maximum distance is always half of the shorter side. + min( + file_settings["world_size_cm_x"] / 2, + file_settings["world_size_cm_y"] / 2, + ), + ], ) plt.title("Distance to closest wall") plt.xlabel("Distance [cm]") @@ -357,15 +434,15 @@ def evaluate_distanceToWall( return fig -def evaluate_tankpositions( - paths: Iterable[str], +def evaluate_tank_position( + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, + poses_from_paths=None, + file_settings=None, ): """Evaluate the positions of the entities as a heatmap. - The original implementation is by Moritz Maxeiner. - Args: paths: An array of strings, with files of folders. labels: Labels for the paths. If no labels are given, the paths will @@ -375,53 +452,56 @@ def evaluate_tankpositions( """ # Parameter to set how many steps are skipped in the kde plot. - poses_step = 5 - - files_per_path = [robofish.io.read_multiple_files(p) for p in paths] - x_pos, y_pos = [], [] - world_bounds = [] - for k, files in enumerate(files_per_path): - path_x_pos, path_y_pos = [], [] - for p, file in files.items(): - poses = file.select_entity_poses( - None if predicate is None else predicate[k] - ) - world_bounds.append(file.attrs["world_size_cm"]) - for e_poses in poses: - path_x_pos.extend(e_poses[::poses_step, 0]) - path_y_pos.extend(e_poses[::poses_step, 1]) - file.close() + poses_step = 20 - # Exclude possible nans - path_x_pos = np.array(path_x_pos) - path_y_pos = np.array(path_y_pos) - path_x_pos = path_x_pos[~np.isnan(path_x_pos)] - path_y_pos = path_y_pos[~np.isnan(path_y_pos)] + if poses_from_paths is None: + poses_from_paths, file_settings = get_all_poses_from_paths(paths) - x_pos.append(path_x_pos) - y_pos.append(path_y_pos) + xy_positions = [] + + # Iterate all paths + for poses_per_path in poses_from_paths: + + # Get all positions for each file (skipping steps with poses_step), flatten them to (..., 2) and concatenate all positions. 
+ new_xy_positions = np.concatenate( + [p[:, ::poses_step, :2].reshape(-1, 2) for p in poses_per_path], axis=0 + ) + # Exclude possible nans + xy_positions.append(new_xy_positions[~np.isnan(new_xy_positions).any(axis=1)]) - fig, ax = plt.subplots(1, len(x_pos), figsize=(8 * len(x_pos), 8)) - if len(x_pos) == 1: + fig, ax = plt.subplots(1, len(xy_positions), figsize=(8 * len(xy_positions), 8)) + if len(xy_positions) == 1: ax = [ax] - for i in range(len(x_pos)): + for i in range(len(xy_positions)): ax[i].set_title("Tankpositions (%s)" % labels[i]) - ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) - ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) + ax[i].set_xlim( + -file_settings["world_size_cm_x"] / 2, file_settings["world_size_cm_x"] / 2 + ) + ax[i].set_ylim( + -file_settings["world_size_cm_y"] / 2, file_settings["world_size_cm_y"] / 2 + ) ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") - sns.kdeplot(x=x_pos[i], y=y_pos[i], n_levels=20, shade=True, ax=ax[i]) + sns.kdeplot( + x=xy_positions[i][:, 0], + y=xy_positions[i][:, 1], + n_levels=20, + shade=True, + ax=ax[i], + ) return fig -def evaluate_socialVec( - paths: Iterable[str], +def evaluate_social_vector( + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, + poses_from_paths=None, + file_settings=None, ): """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap. @@ -432,23 +512,18 @@ def evaluate_socialVec( predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ - files_per_path = [robofish.io.read_multiple_files(p) for p in paths] + + if poses_from_paths is None: + poses_from_paths, file_settings = get_all_poses_from_paths(paths) + socialVec = [] - worldBoundsX, worldBoundsY = [], [] - for k, files in enumerate(files_per_path): + + # Iterate all paths + for poses_per_path in poses_from_paths: path_socialVec = [] - for p, file in files.items(): - worldBoundsX.append(file.attrs["world_size_cm"][0]) - worldBoundsY.append(file.attrs["world_size_cm"][1]) - poses = file.select_entity_poses( - None if predicate is None else predicate[k] - ) - if poses.shape[0] != 2: - print( - "The positionVec can only be calculated when there are exactly 2 fish." - ) - return + # Iterate all files + for poses in poses_per_path: # calculate socialVec for every fish combination for i in range(len(poses)): for j in range(len(poses)): @@ -463,7 +538,6 @@ def evaluate_socialVec( np.arctan2(poses[i, :, 3], poses[i, :, 2]), ) ) - file.close() # Concatenate and exclude possible nans path_socialVec = np.concatenate(path_socialVec, axis=0) @@ -472,7 +546,6 @@ def evaluate_socialVec( socialVec.append(path_socialVec) grids = [] - worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) for i in range(len(socialVec)): df = pd.DataFrame({"x": socialVec[i][:, 0], "y": socialVec[i][:, 1]}) @@ -480,8 +553,8 @@ def evaluate_socialVec( grid.axes[0, 0].set_xlabel("x [cm]") grid.axes[0, 0].set_ylabel("y [cm]") # Limits set by educated guesses. If it doesnt work for your data adjust it. 
- grid.set(xlim=(-worldBoundsX / 7.0, worldBoundsX / 7.0)) - grid.set(ylim=(-worldBoundsY / 7.0, worldBoundsY / 7.0)) + grid.set(xlim=(-10, 10)) + grid.set(ylim=(-10, 10)) grids.append(grid) fig = plt.figure(figsize=(8 * len(grids), 8)) @@ -496,9 +569,11 @@ def evaluate_socialVec( def evaluate_follow_iid( - paths: Iterable[str], + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, + poses_from_paths=None, + file_settings=None, ): """Evaluate the follow metric in respect to the inter individual distance (iid). @@ -511,40 +586,32 @@ def evaluate_follow_iid( predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") """ - files_per_path = [robofish.io.read_multiple_files(p) for p in paths] + if poses_from_paths is None: + poses_from_paths, file_settings = get_all_poses_from_paths(paths) + follow, iid = [], [] - worldBoundsX, worldBoundsY = [], [] - for k, files in enumerate(files_per_path): - path_follow, path_iid = [], [] - for p, file in files.items(): - worldBoundsX.append(file.attrs["world_size_cm"][0]) - worldBoundsY.append(file.attrs["world_size_cm"][1]) - poses = file.select_entity_poses( - None if predicate is None else predicate[k] - ) - if poses.shape[0] < 2: - print( - "The FollowIID can only be calculated when there are at least 2 fish." - ) - return + # Iterate all paths + for poses_per_path in poses_from_paths: + path_iid, path_follow = [], [] + # Iterate all files + for poses in poses_per_path: for i in range(len(poses)): for j in range(len(poses)): if i != j and (poses[i] != poses[j]).any(): path_iid.append( - calculate_iid(poses[i, :-1, 0:2], poses[j, :-1, 0:2]) + calculate_iid(poses[i, :-1, :2], poses[j, :-1, :2]) ) path_follow.append( - calculate_follow(poses[i, :, 0:2], poses[j, :, 0:2]) + calculate_follow(poses[i, :, :2], poses[j, :, :2]) ) - file.close() + follow.append(np.array(path_follow)) iid.append(np.array(path_iid)) grids = [] - worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) follow_flat = np.concatenate([f.flatten() for f in follow]) iid_flat = np.concatenate([i.flatten() for i in iid]) @@ -602,10 +669,10 @@ def evaluate_follow_iid( def evaluate_tracks_distance( - paths: Iterable[str], + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, - max_timesteps=None, + max_timesteps=1000, ): """Evaluate the distances of two or more fish on the track. @@ -622,12 +689,14 @@ def evaluate_tracks_distance( def evaluate_tracks( - paths: Iterable[str], + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, predicate=None, + poses_from_paths=None, + file_settings=None, lw_distances=False, seed=42, - max_timesteps=None, + max_timesteps=1000, ): """Evaluate the track. 
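Stepping back to the social-vector evaluation a few hunks above: the diff truncates the exact helper call, but conceptually the quantity being histogrammed is the conspecific's position expressed in the focal fish's egocentric frame. A generic sketch of that transformation; this is my own illustration and not necessarily the exact robofish.io.utils implementation.

import numpy as np

def social_vector(focal_pose, other_pose):
    """Relative position of `other` in the focal fish's frame (x ahead, y to the left).
    Pose layout as in poses_4d: (x, y, orientation_x, orientation_y)."""
    dx = other_pose[0] - focal_pose[0]
    dy = other_pose[1] - focal_pose[1]
    heading = np.arctan2(focal_pose[3], focal_pose[2])
    c, s = np.cos(-heading), np.sin(-heading)
    # Rotate the relative position by -heading so the focal fish looks along +x.
    return np.array([c * dx - s * dy, s * dx + c * dy])

# Example: a conspecific 5 cm ahead of a fish that is heading along +y.
print(social_vector(np.array([0.0, 0.0, 0.0, 1.0]), np.array([0.0, 5.0, 1.0, 0.0])))
# -> approximately [5, 0]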
@@ -641,7 +710,8 @@ def evaluate_tracks(
 
     random.seed(seed)
 
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+    files_per_path = get_all_files_from_paths(paths)
+
     max_files_per_path = max([len(files) for files in files_per_path])
 
     rows, cols = len(files_per_path), min(4, max_files_per_path)
@@ -652,34 +722,34 @@
 
     fig, ax = plt.subplots(rows, cols, figsize=(cols * 4, rows * 4), squeeze=False)
 
-    for k, files in enumerate(files_per_path):
-        file_paths = list(files.keys())
-        random.shuffle(file_paths)
-        for i, path in enumerate(file_paths):
-            file = files[path]
-            if multirow:
-                if i >= cols * rows:
-                    break
+    # Iterate all paths
+    for k, files_in_path in enumerate(files_per_path):
+
+        random.shuffle(files_in_path)
+        # Iterate all files
+        for i, file_path in enumerate(files_in_path):
+            with robofish.io.File(file_path, "r") as file:
+                if multirow:
+                    if i >= cols * rows:
+                        break
+                    selected_ax = ax[i // cols][i % cols]
+                else:
+                    if i >= cols:
+                        break
+                    selected_ax = ax[k][i]
+
                 file.plot(
-                    ax[i // cols][i % cols],
+                    selected_ax,
                     lw_distances=lw_distances,
                     max_timesteps=max_timesteps,
                 )
-            else:
-                if i >= cols:
-                    break
-                file.plot(
-                    ax[k][i], lw_distances=lw_distances, max_timesteps=max_timesteps
-                )
-
-            file.close()
 
     plt.tight_layout()
 
     return fig
 
 
-def evaluate_individual_speeds(**kwargs):
+def evaluate_individual_speed(**kwargs):
     """Evaluate the average speeds per individual"""
     return evaluate_individuals(mode="speed", **kwargs)
@@ -691,8 +761,11 @@ def evaluate_individual_iid(**kwargs):
 
 def evaluate_individuals(
     mode: str,
-    paths: Iterable[str],
+    paths: Iterable[Union[str, Path]],
     labels: Iterable[str] = None,
+    speeds_turns_from_paths=None,
+    poses_from_paths=None,
+    file_settings=None,
     predicate=None,
     lw_distances=False,
     threshold=7,
@@ -708,23 +781,32 @@
             (example: lambda e: e.category == "fish")
     """
 
-    files_per_path = [robofish.io.read_multiple_files(p) for p in paths]
+    if speeds_turns_from_paths is None and mode == "speed":
+        speeds_turns_from_paths, file_settings = get_all_data_from_paths(
+            paths, "speeds_turns"
+        )
+    if poses_from_paths is None and mode == "iid":
+        poses_from_paths, file_settings = get_all_data_from_paths(paths, "poses_4d")
+
+    files_from_paths = get_all_files_from_paths(paths)
 
     fig = plt.figure(figsize=(10, 4))
 
     # small_iid_files = []
     offset = 0
-    for k, files in enumerate(files_per_path):
+    for k, files_in_paths in enumerate(files_from_paths):
         all_avg = []
         all_std = []
-        for path, file in files.items():
+        for f, file_path in enumerate(files_in_paths):
             if mode == "speed":
-                metric = file.entity_actions_speeds_turns[..., 0] * file.frequency
+                metric = (
+                    speeds_turns_from_paths[k][f][..., 0] * file_settings["frequency_hz"]
+                )
             elif mode == "iid":
-                poses = file.entity_poses_rad
+                poses = poses_from_paths[k][f]
                 if poses.shape[0] != 2:
                     print(
-                        "The evaluate_individual_iid function only works when there are exactly 2 individuals."
+                        "The evaluate_individual_iid function only works when there are exactly 2 individuals. If you need it for multiple agents please implement it."
) return metric = calculate_iid(poses[0], poses[1])[None] @@ -734,7 +816,7 @@ def evaluate_individuals( all_avg.append(np.nanmean(metric, axis=1)) all_std.append(np.nanstd(metric, axis=1)) - file.close() + all_avg = np.concatenate(all_avg, axis=0) all_std = np.concatenate(all_std, axis=0) individuals = all_avg.shape[0] @@ -783,7 +865,7 @@ def evaluate_individuals( def evaluate_all( - paths: Iterable[str], + paths: Iterable[Union[str, Path]], labels: Iterable[str] = None, save_folder: Path = None, fdict: dict = None, @@ -813,12 +895,32 @@ def evaluate_all( fdict = robofish.evaluate.app.function_dict() fdict.pop("all") + print("Loading all poses and actions.") + poses_from_paths, file_settings = get_all_poses_from_paths(paths) + speeds_turns_from_paths, file_settings = get_all_data_from_paths( + paths, "speeds_turns" + ) + + input_dict = { + "poses_from_paths": poses_from_paths, + "speeds_turns_from_paths": speeds_turns_from_paths, + "file_settings": file_settings, + } + t = tqdm(fdict.items(), desc="Evaluation", leave=True) for f_name, f_callable in t: t.set_description(f_name) - t.refresh() # to show immediately the update + t.refresh() # to show the update immediately save_path = save_folder / (f_name + ".png") - fig = f_callable(paths=paths, labels=labels, predicate=predicate) + + requested_inputs = { + k: input_dict[k] + for k in inspect.signature(f_callable).parameters.keys() + if k in input_dict + } + fig = f_callable( + paths=paths, labels=labels, predicate=predicate, **requested_inputs + ) if fig is not None: fig.savefig(save_path) plt.close(fig) diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index a9344741f866ff3a6ad1fb4f85d4f44cef1c1161..d4634948dd740276eb6d910f86778db1188883a8 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -37,12 +37,11 @@ def print_file(args=None): if args is None: args = parser.parse_args() - sf = robofish.io.File(path=args.path, strict_validate=False) - - print(sf.to_string(args.output_format)) - print() - valid = sf.validate(strict_validate=False)[0] - print("Valid file" if valid else "Invalid file") + with robofish.io.File(path=args.path, strict_validate=False) as f: + print(f.to_string(args.output_format)) + print() + valid = f.validate(strict_validate=False)[0] + print("Valid file" if valid else "Invalid file") return not valid @@ -95,6 +94,7 @@ def validate(args=None): error_code = 0 for file, sf in sf_dict.items(): filled_file = (str)(file).ljust(max_filename_width + 3) + file.close() validity, validity_message = sf.validate(strict_validate=False) if not validity: error_code = 1 diff --git a/src/robofish/io/io.py b/src/robofish/io/io.py index 7645a44ca5d2292cbcaf089853557ab4d763305a..b1c41c692cf25932689df2f0c0a21cd07c8599fc 100644 --- a/src/robofish/io/io.py +++ b/src/robofish/io/io.py @@ -7,6 +7,7 @@ import logging import numpy as np import pandas import random +import deprecation def now_iso8061() -> str: @@ -20,6 +21,11 @@ def now_iso8061() -> str: ) +@deprecation.deprecated( + deprecated_in="0.2.10", + removed_in="0.3", + details="Loading all files first and then handling them is slow, memory intensive and inconvenient. 
Don't use this method.",
+)
 def read_multiple_files(
     paths: Union[Path, str, Iterable[Path], Iterable[str]],
     strict_validate: bool = False,
diff --git a/tests/robofish/evaluate/test_evaluate.py b/tests/robofish/evaluate/test_evaluate.py
index 0476c18985fe065ee8f7566e0581461ed6666379..1059d1ee71881ac0507673517c12e2ba0986e0bc 100644
--- a/tests/robofish/evaluate/test_evaluate.py
+++ b/tests/robofish/evaluate/test_evaluate.py
@@ -6,8 +6,7 @@ import numpy as np
 
 def test_get_all_poses_from_paths():
     valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5")
-    poses, frequency = robofish.evaluate.get_all_poses_from_paths([valid_file_path])
+    poses, file_settings = robofish.evaluate.get_all_poses_from_paths([valid_file_path])
 
     # (1 input array, 1 file, 2 fishes, 100 timesteps, 4 poses)
     assert np.array(poses).shape == (1, 1, 2, 100, 4)
-    assert frequency == 25
+    assert file_settings["frequency_hz"] == 25
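Finally, for callers that still rely on the now-deprecated read_multiple_files, a migration sketch toward the context-manager pattern used throughout this patch. The "tracks/" folder is invented and the import path of get_all_files_from_paths is an assumption.

import robofish.io
from robofish.evaluate.evaluate import get_all_files_from_paths

# Deprecated: files = robofish.io.read_multiple_files(["tracks/"])
# Preferred: discover the files first, then keep each one open only while it is needed.
for file_path in get_all_files_from_paths(["tracks/"])[0]:
    with robofish.io.File(file_path, "r") as f:
        poses = f.entity_poses  # attribute used by the pre-patch loader
        print(file_path, poses.shape)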