diff --git a/setup.py b/setup.py index 6a09e4465d08523d1689a9ab9fc63eadf4ad6bf3..5af4ce7e8b8de7f1b7c148735c717543b8a9d22f 100644 --- a/setup.py +++ b/setup.py @@ -15,6 +15,7 @@ entry_points = { # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz "robofish-io-evaluate=robofish.evaluate.app:evaluate", "robofish-io-update-individual-ids = robofish.io.app:update_individual_ids", + "robofish-io-update-world-shape = robofish.io.app:update_world_shape", "robofish-io-overwrite_user_configs = robofish.io.app:overwrite_user_configs", "robofish-io-fix-switches = robofish.io.fix_switches:app_fix_switches", ] diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 4e73418db4a854c535bbeee1bbc193f4d73cd3bd..2f4091db6ac1a0c51d40b42a3bb31ddd5086be14 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -17,6 +17,7 @@ import argparse import logging import warnings from tqdm.auto import tqdm + from typing import Dict import itertools import numpy as np @@ -397,3 +398,41 @@ def update_individual_ids(args=None): if y == "y": print(f"Could not read individual_id from {fp}") print(e) + + +def update_world_shape(args: dict = None) -> None: + """Update the world shape attribute of files. + + Args: + args (dict): Passthrough for argparser args. Defaults to None. + """ + if args is None: + parser = argparse.ArgumentParser() + parser.add_argument( + "path", + type=str, + nargs="+", + help="The path to one or multiple files and/or folders.", + ) + parser.add_argument("world_shape", type=str, help="The world shape.") + args = parser.parse_args() + + assert args.world_shape in [ + "rectangle", + "ellipse", + ], "The world shape must be either 'rectangle' or 'ellipse'." + + files_per_path = utils.get_all_files_from_paths(args.path) + files_per_path = [sorted(f_in_path) for f_in_path in files_per_path] + files = list(itertools.chain.from_iterable(files_per_path)) + + pbar = tqdm(files) + for file in pbar: + + pbar.set_description(f"Updating {file.name}") + with robofish.io.File(file, "r+") as f: + f.attrs["world_shape"] = args.world_shape + + pbar.update(1) + + print("Update finished.") diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 8df570f42264ede992dbdf35d49ef2fcc552dd2f..b81711a1fff8adfa20346677247752bc1a64e6e6 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -418,7 +418,7 @@ class File(h5py.File): def world_shape(self): if "world_shape" not in self.attrs: warnings.warn( - "File did not have a world_shape attribute. Assuming rectangle." + f"File {self.filename} did not have a world_shape attribute, assuming rectangle.\nPlease use robofish-io-update-world-shape {self.filename} rectangle to fix this." ) return "rectangle" return self.attrs["world_shape"]