From 90d66a0ba160ae7fa5737bc09298db888a43cb05 Mon Sep 17 00:00:00 2001
From: Andi <andi.gerken@gmail.com>
Date: Tue, 25 Jul 2023 15:24:01 +0200
Subject: [PATCH] Added robofish-io-update-world-shape

---
 setup.py                |  1 +
 src/robofish/io/app.py  | 39 +++++++++++++++++++++++++++++++++++++++
 src/robofish/io/file.py |  2 +-
 3 files changed, 41 insertions(+), 1 deletion(-)

diff --git a/setup.py b/setup.py
index 6a09e44..5af4ce7 100644
--- a/setup.py
+++ b/setup.py
@@ -15,6 +15,7 @@ entry_points = {
         # TODO: This should be called robofish-evaluate which is not possible because of the package name (guess) ask moritz
         "robofish-io-evaluate=robofish.evaluate.app:evaluate",
         "robofish-io-update-individual-ids = robofish.io.app:update_individual_ids",
+        "robofish-io-update-world-shape = robofish.io.app:update_world_shape",
         "robofish-io-overwrite_user_configs = robofish.io.app:overwrite_user_configs",
         "robofish-io-fix-switches = robofish.io.fix_switches:app_fix_switches",
     ]
diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index 4e73418..2f4091d 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -17,6 +17,7 @@ import argparse
 import logging
 import warnings
 from tqdm.auto import tqdm
+
 from typing import Dict
 import itertools
 import numpy as np
@@ -397,3 +398,41 @@ def update_individual_ids(args=None):
                 if y == "y":
                     print(f"Could not read individual_id from {fp}")
                     print(e)
+
+
+def update_world_shape(args: dict = None) -> None:
+    """Update the world shape attribute of files.
+
+    Args:
+        args (dict): Passthrough for argparser args. Defaults to None.
+    """
+    if args is None:
+        parser = argparse.ArgumentParser()
+        parser.add_argument(
+            "path",
+            type=str,
+            nargs="+",
+            help="The path to one or multiple files and/or folders.",
+        )
+        parser.add_argument("world_shape", type=str, help="The world shape.")
+        args = parser.parse_args()
+
+    assert args.world_shape in [
+        "rectangle",
+        "ellipse",
+    ], "The world shape must be either 'rectangle' or 'ellipse'."
+
+    files_per_path = utils.get_all_files_from_paths(args.path)
+    files_per_path = [sorted(f_in_path) for f_in_path in files_per_path]
+    files = list(itertools.chain.from_iterable(files_per_path))
+
+    pbar = tqdm(files)
+    for file in pbar:
+
+        pbar.set_description(f"Updating {file.name}")
+        with robofish.io.File(file, "r+") as f:
+            f.attrs["world_shape"] = args.world_shape
+
+        pbar.update(1)
+
+    print("Update finished.")
diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py
index 8df570f..b81711a 100644
--- a/src/robofish/io/file.py
+++ b/src/robofish/io/file.py
@@ -418,7 +418,7 @@ class File(h5py.File):
     def world_shape(self):
         if "world_shape" not in self.attrs:
             warnings.warn(
-                "File did not have a world_shape attribute. Assuming rectangle."
+                f"File {self.filename} did not have a world_shape attribute, assuming rectangle.\nPlease use robofish-io-update-world-shape {self.filename} rectangle to fix this."
             )
             return "rectangle"
         return self.attrs["world_shape"]
-- 
GitLab