From 8feae667bc6c42c4d51e149cbce7e94f00eaac6b Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Tue, 25 Jul 2023 12:42:22 +0200 Subject: [PATCH] Added world shape --- src/conversion_scripts/convert_from_csv.py | 30 ++++++++++++------- src/robofish/evaluate/app.py | 35 ++++++++++++++++++++-- src/robofish/evaluate/evaluate.py | 35 ++++++++++++---------- src/robofish/io/file.py | 28 +++++++++++++---- src/robofish/io/validation.py | 7 +++++ 5 files changed, 101 insertions(+), 34 deletions(-) diff --git a/src/conversion_scripts/convert_from_csv.py b/src/conversion_scripts/convert_from_csv.py index 5ef4107..aadec8b 100644 --- a/src/conversion_scripts/convert_from_csv.py +++ b/src/conversion_scripts/convert_from_csv.py @@ -141,6 +141,10 @@ def handle_switches(poses, supress=None): def handle_file(file, args): pf = pd.read_csv(file, skiprows=4, header=None, sep=args.sep) + assert ( + len(pf.columns) >= 2 + ), f"{pf}\nThere were only {len(pf.columns)} columns in {file}. Looks like the seperator was maybe wrong? Choose it with --sep. Default is ;" + if args.columns_per_entity is None: all_col_types_matching = [] for cols in range(1, len(pf.columns) // 2 + 1): @@ -163,18 +167,15 @@ def handle_file(file, args): all_col_types_matching.append(cols) # print(cols, "\t", matching_types) - if len(all_col_types_matching) == 0: - print( - "Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity" - ) - else: - assert [ - all_col_types_matching[i] % all_col_types_matching[0] == 0 - for i in range(0, len(all_col_types_matching)) - ], f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}" - columns_per_entity = all_col_types_matching[0] + assert ( + len(all_col_types_matching) > 0 + ), f"Error: Could not detect columns_per_entity. Please specify manually using --columns_per_entity\ndatatypes were {dt}" + assert [ + all_col_types_matching[i] % all_col_types_matching[0] == 0 + for i in range(0, len(all_col_types_matching)) + ], f"Error: Found multiple columns_per_entity which were not multiples of each other {all_col_types_matching}" + columns_per_entity = all_col_types_matching[0] print("Found columns_per_entity: %d" % columns_per_entity) - else: columns_per_entity = int(args.columns_per_entity) @@ -204,6 +205,7 @@ def handle_file(file, args): io_file_path, "w", world_size_cm=[args.world_size, args.world_size], + world_shape=args.world_shape, frequency_hz=args.frequency, ) as iof: poses = np.empty((pf_np.shape[0], n_fish, 3), dtype=np.float32) @@ -380,6 +382,7 @@ parser = argparse.ArgumentParser( ) parser.add_argument("path", nargs="+") parser.add_argument("-o", "--output", default=None) +parser.add_argument("--world_shape", type=str) parser.add_argument("--sep", default=";") parser.add_argument("--header", default=3) parser.add_argument("--columns_per_entity", default=None) @@ -399,6 +402,11 @@ parser.add_argument( ) args = parser.parse_args() +assert args.world_shape in [ + "rectangle", + "ellipse", +], f"--world_shape has to be 'rectangle' or 'ellipse' but is '{args.world_shape}'" + if args.analysis_path is not None: args.analysis_path = Path(args.analysis_path) diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index f0c7526..c7d0a44 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -14,6 +14,7 @@ import robofish.evaluate import argparse from pathlib import Path import matplotlib.pyplot as plt +from copy import copy def function_dict() -> dict: @@ -98,6 +99,12 @@ def evaluate(args: dict = None) -> None: help="Filename for saving resulting graphics.", default=None, ) + parser.add_argument( + "--add_train_data", + action="store_true", + help="Add the training data to the evaluation.", + default=False, + ) # TODO: ignore fish/ consider_names @@ -108,11 +115,33 @@ def evaluate(args: dict = None) -> None: raise ValueError("When the analysis type is all, a --save_path must be given.") if args.analysis_type in fdict: - if args.labels is None: - args.labels = args.paths + paths = args.paths + labels = args.labels + print("starting", paths, labels) + + if labels is None: + labels = copy(paths) + + if args.add_train_data: + # Open any file to get the path + if len(paths) > 1: + warnings.warn( + "Multiple paths given. Only the first path is used for the training data." + ) + + files = list(Path(args.paths[0]).rglob("*.hdf5")) + if len(files) == 0: + warnings.warn("No hdf5 files found in the given path.") + with robofish.io.File(files[0]) as f: + train_data = str(f.attrs["training_data"]) + + paths += [train_data] + labels += ["/".join(Path(train_data).parts[-2:])] + + print("starting", paths, labels) save_path = None if args.save_path is None else Path(args.save_path) - params = {"paths": args.paths, "labels": args.labels} + params = {"paths": paths, "labels": labels} if args.analysis_type == "all": normal_functions = function_dict() normal_functions.pop("all") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 1d5b3f4..bede8f7 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -185,7 +185,6 @@ def evaluate_orientation( orientations = [] # Iterate all paths for poses_per_path in poses_from_paths: - # Iterate all files for poses in poses_per_path: reshaped_poses = poses.reshape((-1, 4)) @@ -340,7 +339,6 @@ def evaluate_distance_to_wall( # Iterate all files for poses in poses_per_path: - for e_poses in poses: dist = [] for wall in wall_lines: @@ -414,7 +412,6 @@ def evaluate_tank_position( # Iterate all paths for poses_per_path in poses_from_paths: - # Get all positions for each file (skipping steps with poses_step), flatten them to (..., 2) and concatenate all positions. new_xy_positions = np.concatenate( [p[:, :, :2].reshape(-1, 2) for p in poses_per_path], axis=0 @@ -427,7 +424,6 @@ def evaluate_tank_position( ax = [ax] for i in range(len(xy_positions)): - ax[i].set_xlim( -file_settings["world_size_cm_x"] / 2, file_settings["world_size_cm_x"] / 2 ) @@ -682,14 +678,21 @@ def evaluate_social_vector( axis=-1, ) - print(np.stack(poses).shape) social_vec = SocialVectors(poses).social_vectors_without_focal_zeros flat_sv = social_vec.reshape((-1, 3)) + bins = ( + 30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100 + ) + ax[i].hist2d( - flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=100 + flat_sv[:, 0], flat_sv[:, 1], range=[[-7.5, 7.5], [-7.5, 7.5]], bins=bins ) ax[i].set_title(labels[i]) + if bins < 100: + ax[i].set_title( + f"{labels[i]} downsampled to {bins} bins.\nFor more details generate {60000/flat_sv.shape[0]:.1f} times as much." + ) plt.suptitle("Social Vectors") return fig @@ -734,9 +737,13 @@ def evaluate_follow_iid( calculate_iid(poses[i, :-1, :2], poses[j, :-1, :2]) ) - path_follow.append( - calculate_follow(poses[i, :, :2], poses[j, :, :2]) - ) + file_follow = calculate_follow(poses[i, :, :2], poses[j, :, :2]) + + # Remove inf values and replace them with 0 + file_follow = np.where(np.isnan(file_follow), 0, file_follow) + file_follow = np.where(np.isinf(file_follow), 0, file_follow) + + path_follow.append(file_follow) follow.append(np.concatenate([np.atleast_1d(f) for f in path_follow])) iid.append(np.concatenate([np.atleast_1d(i) for i in path_iid])) @@ -753,7 +760,6 @@ def evaluate_follow_iid( max_iid = np.quantile(np.concatenate(iid), 0.9) for i in range(len(follow)): - # Mask data which is outside of the ranges mask = ( (follow[i] > -1 * follow_range) @@ -914,7 +920,6 @@ def evaluate_tracks( if verbose: print(f"Using file {file_path}") with robofish.io.File(new_file_path, "r") as new_file: - new_file.plot( selected_ax, lw_distances=lw_distances, @@ -925,13 +930,11 @@ def evaluate_tracks( f"{selected_ax.get_title()} (reference track: id {initial_poses_info['track_id']})", ) else: - with robofish.io.File(file_path, "r") as file: if ( "initial_poses_info" in file and "track_id" in file["initial_poses_info"].attrs ): - # Save the corresponding file that a model mimics. initial_poses_info = { a: file["initial_poses_info"].attrs[a] @@ -1139,6 +1142,7 @@ def evaluate_all( predicate: Callable[[robofish.io.Entity], bool] = None, max_files: int = None, evaluations: List[str] = None, + file_format: str = "png", ) -> Iterable[Path]: """Generate all evaluation graphs and save them to a folder. @@ -1184,7 +1188,7 @@ def evaluate_all( if evaluations is None or f_name in evaluations: t.set_description(f_name) t.refresh() # to show the update immediately - save_path = save_folder / (f_name + ".png") + save_path = save_folder / (f_name + "." + file_format) requested_inputs = { k: input_dict[k] @@ -1195,7 +1199,8 @@ def evaluate_all( paths=paths, labels=labels, predicate=predicate, **requested_inputs ) if fig is not None: - fig.savefig(save_path) + # Increase resolution + fig.savefig(save_path, dpi=300, bbox_inches="tight") plt.close(fig) save_paths.append(save_path) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index acdec48..8df570f 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -67,7 +67,7 @@ class File(h5py.File): mode: str = "r", *, # PEP 3102 world_size_cm: List[int] = None, - world_shape: str = "rectangle", + world_shape: str = None, validate: bool = False, validate_when_saving: bool = True, strict_validate: bool = False, @@ -211,6 +211,15 @@ class File(h5py.File): ), "world_size_cm and format_version have to be given when creating a new file." self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32) + + assert ( + world_shape is not None + ), "world_shape has to be given when creating a new file." + assert world_shape in [ + "rectangle", + "ellipse", + ], f"Unknown world shape {world_shape}." + self.attrs["world_shape"] = world_shape self.attrs["format_version"] = np.array(format_version, dtype=np.int32) self.attrs["format_url"] = format_url @@ -405,6 +414,15 @@ class File(h5py.File): def world_size(self): return self.attrs["world_size_cm"] + @property + def world_shape(self): + if "world_shape" not in self.attrs: + warnings.warn( + "File did not have a world_shape attribute. Assuming rectangle." + ) + return "rectangle" + return self.attrs["world_shape"] + @property def default_sampling(self): assert ( @@ -855,7 +873,7 @@ class File(h5py.File): lw: int = 2, ms: int = 32, figsize: Tuple[int] = None, - step_size: int = 4, + step_size: int = 25, c: List = None, cmap: matplotlib.colors.Colormap = "Set1", skip_timesteps=0, @@ -1117,12 +1135,12 @@ class File(h5py.File): y = np.array([-1, -1, 1, 1, -1]) * self.world_size[1] / 2 return np.array([x, y]) - if "world_shape" not in self.attrs or self.attrs["world_shape"] == "rectangle": + if self.world_shape == "rectangle": border_vertices = create_square() - elif self.attrs["world_shape"] == "ellipse": + elif self.world_shape == "ellipse": border_vertices = create_circle(150) else: - raise ValueError(f"Unknown world shape: {self.attrs['world_shape']}") + raise ValueError(f"Unknown world shape: {self.world_shape}") spacing = 10 x = np.arange( diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index b211e10..1371a63 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -92,6 +92,13 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): assert_validate(a in iofile.attrs, f'Attribute "{a}" missing","root') assert_validate_type(iofile.attrs[a], a_type, a, "root") + # Validate world shape + if "world_shape" in iofile.attrs: + assert iofile.attrs["world_shape"] in [ + "rectangle", + "ellipse", + ], f"Invalid world shape {iofile.attrs['world_shape']}. Allowed are 'rectangle' and 'ellipse'." + # validate samplings assert_validate("samplings" in iofile, "samplings not found") for s_name, sampling in iofile["samplings"].items(): -- GitLab