From d1eb5fa1370657a71f9c17541eb5ac543bc093eb Mon Sep 17 00:00:00 2001
From: Andi <andi.gerken@gmail.com>
Date: Thu, 27 Jul 2023 17:52:06 +0200
Subject: [PATCH] updated tests for world shape

---
 examples/example_basic.ipynb                 |   2 +-
 examples/example_basic.py                    |   4 +-
 examples/example_readme.py                   |   4 +-
 src/robofish/io/app.py                       |   4 +-
 tests/resources/invalid.hdf5                 | Bin 20760 -> 26904 bytes
 tests/resources/nan_test.hdf5                | Bin 15576 -> 21720 bytes
 tests/resources/valid_1.hdf5                 | Bin 23272 -> 29416 bytes
 tests/resources/valid_2.hdf5                 | Bin 22488 -> 28632 bytes
 tests/robofish/evaluate/test_app_evaluate.py |  49 +++++++++----------
 tests/robofish/io/test_entity.py             |   8 ++-
 tests/robofish/io/test_file.py               |  34 +++++++++----
 11 files changed, 63 insertions(+), 42 deletions(-)

diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb
index 7fa6591..23005db 100644
--- a/examples/example_basic.ipynb
+++ b/examples/example_basic.ipynb
@@ -12,7 +12,7 @@
     "\n",
     "def create_example_file(path):\n",
     "    # Create a new io file object with a 100x100cm world\n",
-    "    f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)\n",
+    "    f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0, world_shape=\"rectangle\")\n",
     "\n",
     "    # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.\n",
     "    robofish_timesteps = 1000\n",
diff --git a/examples/example_basic.py b/examples/example_basic.py
index f37d0cc..a825774 100755
--- a/examples/example_basic.py
+++ b/examples/example_basic.py
@@ -5,7 +5,9 @@ import numpy as np
 def create_example_file(path):
     # Create a new io file object with a 100x100cm world.
     # Mode "w" means that the file should be opened with write access.
-    f = robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0)
+    f = robofish.io.File(
+        path, "w", world_size_cm=[100, 100], frequency_hz=25.0, world_shape="rectangle"
+    )
 
     # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.
     robofish_timesteps = 1000
diff --git a/examples/example_readme.py b/examples/example_readme.py
index 7b2e22c..cf77488 100644
--- a/examples/example_readme.py
+++ b/examples/example_readme.py
@@ -5,7 +5,9 @@ import numpy as np
 def create_example_file(path):
     # Create a new robofish io file
     # Mode "w" means that the file should be opened with write access.
-    f = robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0)
+    f = robofish.io.File(
+        path, "w", world_size_cm=[100, 100], frequency_hz=25.0, world_shape="rectangle"
+    )
     f.attrs["experiment_setup"] = "This is a simple example with made up data."
 
     # Create a new robot entity with 10 timesteps.
diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py
index 2f4091d..0efc84c 100644
--- a/src/robofish/io/app.py
+++ b/src/robofish/io/app.py
@@ -430,7 +430,9 @@ def update_world_shape(args: dict = None) -> None:
     for file in pbar:
 
         pbar.set_description(f"Updating {file.name}")
-        with robofish.io.File(file, "r+") as f:
+        with robofish.io.File(
+            file, "r+", validate_when_saving=False, calculate_data_on_close=False
+        ) as f:
             f.attrs["world_shape"] = args.world_shape
 
         pbar.update(1)
diff --git a/tests/resources/invalid.hdf5 b/tests/resources/invalid.hdf5
index 18e2361c58dce4478a173e40f9c998cda90beac7..e76736436cc0d0499e8a95a031bccc449a07fc50 100755
GIT binary patch
delta 140
zcmbQSh;hax#tE8?nG>}v7`Znlv$6{cFbFV!fJ7jOWMBxGxUqc`hlHmz4}%9-gpq-V
zL4<*Wp*+7RCndf(Be5WLayN@AFDC;7gG3})2?OKiM2^kulQ<+4L82hBqSWM)#Ju#J
SR1gOVU)UVTv7Ue80SN$`5gQ-?

delta 30
mcmbPniE+jv#tE8?ffKbX7&$j4v$8WYFi1??ICT?;geL%$^a$Jl

diff --git a/tests/resources/nan_test.hdf5 b/tests/resources/nan_test.hdf5
index 92c3050e3b5204721b0eb0a0543adf0c4d5f65c0..daabe3c98a442a0f35bfe834d9c09cae6cd8ad59 100644
GIT binary patch
delta 138
zcmcand1EEx1Wm?}iCPwn>>HEMG7AbY2rz)a4I2>2z!0$6k!8AtG!KIZn9s<-!yv-I
z!BC!Gl#>!)oRL_NI{7!VDlaDk1H%m$uo4Ew%^Nv3vrpoXPy&g9#EMdrOA_<ab5cPZ
QBz%E?^F|JT{)q?50XmT#UjP6A

delta 32
ocmcbylJQ361WiVpiCPwuFR_bntlZAb%)lTpaii#F50(iQ0L`)sRR910

diff --git a/tests/resources/valid_1.hdf5 b/tests/resources/valid_1.hdf5
index 664e40a39b908364ac70cf99dc11715a6ecd2abc..cb5d38c857def87d8719ef309c189fe4374bc308 100644
GIT binary patch
delta 143
zcmaE{mGQ+>#tE8?MH96w7`Zkk?`0AaU=UycffrFAl7S(EVY3nQw(!Xc%u<pbU?D~Z
z9tIHx4u<mlqMVfY;*7+C)XBe@Re3oX7#Ln8ft4^YZg%9H%|3}kLJ1@a5-UngE=kNw
W&q)Pwknjcm&5oS=`8FP4X9NH~DIS*q

delta 41
tcmaFyl<~z@#tE8?Q4_T+CKnlqZLC<qH2Hu$2UiCJ1gv1%Y{a}R8~{W;4Ltw=

diff --git a/tests/resources/valid_2.hdf5 b/tests/resources/valid_2.hdf5
index eb8c69e1a01f1704c29c94b7c85590c5443275dd..a080dda64d6ba7d59d2505482757249725376005 100644
GIT binary patch
delta 138
zcmcbyp7F+g#tE8?`4hD)7}+-_pJf&lU=Uycfg9l<l7S&$vm?v&U}+u(4=|sRfrmka
zfrFtuzbGdqzBnVXAa(L@W>sEJ1_p*3@n9tkjGHfVY-XRtA)yEo1&I}<CYL1UrRSuA
RI7s-y=7Sve{1Xp!004^W9$5eY

delta 30
mcmca{pYg_e#tE8?;S;qiCSPI~-&nbwdGY}ziOn7?6M_NAHw+s9

diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py
index d8b1d69..4b1332f 100644
--- a/tests/robofish/evaluate/test_app_evaluate.py
+++ b/tests/robofish/evaluate/test_app_evaluate.py
@@ -14,46 +14,43 @@ h5py_file_2 = utils.full_path(__file__, "../../resources/valid_2.hdf5")
 nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5")
 
 
+class DummyArgs:
+    def __init__(self, analysis_type, paths, save_path, add_train_data):
+        self.paths = paths
+        self.names = None
+        self.analysis_type = analysis_type
+        self.save_path = save_path
+        self.labels = None
+        self.add_train_data = add_train_data
+
+
 def test_app_validate(tmp_path):
     """This tests the function of the robofish-io-validate command"""
 
-    class DummyArgs:
-        def __init__(self, analysis_type, paths, save_path):
-            self.paths = paths
-            self.names = None
-            self.analysis_type = analysis_type
-            self.save_path = save_path
-            self.labels = None
-
     for mode in app.function_dict().keys():
         if mode == "all":
-            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path))
-            app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path))
-            app.evaluate(DummyArgs(mode, [nan_file_path], tmp_path))
+            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path, False))
+            app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path, False))
+            app.evaluate(DummyArgs(mode, [nan_file_path], tmp_path, False))
         else:
-            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png"))
+            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png", False))
             app.evaluate(
-                DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png")
+                DummyArgs(
+                    mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png", False
+                )
             )
 
 
 def test_app_validate(tmp_path):
     """This tests the function of the robofish-io-validate command"""
-
-    class DummyArgs:
-        def __init__(self, analysis_type, paths, save_path):
-            self.paths = paths
-            self.names = None
-            self.analysis_type = analysis_type
-            self.save_path = save_path
-            self.labels = None
-
     for mode in app.function_dict().keys():
         if mode == "all":
-            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path))
-            app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path))
+            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path, False))
+            app.evaluate(DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path, False))
         else:
-            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png"))
+            app.evaluate(DummyArgs(mode, [h5py_file_1], tmp_path / "image.png", False))
             app.evaluate(
-                DummyArgs(mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png")
+                DummyArgs(
+                    mode, [h5py_file_2, h5py_file_2], tmp_path / "image.png", False
+                )
             )
diff --git a/tests/robofish/io/test_entity.py b/tests/robofish/io/test_entity.py
index 0d843bb..38d50de 100644
--- a/tests/robofish/io/test_entity.py
+++ b/tests/robofish/io/test_entity.py
@@ -4,7 +4,9 @@ import numpy as np
 
 
 def test_entity_object():
-    sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25)
+    sf = robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    )
     f = sf.create_entity("fish", positions=[[10, 10]])
     assert type(f) == robofish.io.Entity, "Type of entity was wrong"
     assert f.name == "fish_1", "Name of entity was wrong"
@@ -40,7 +42,9 @@ def test_entity_turns_speeds():
     There is an open issue, to check this test again: https://git.imp.fu-berlin.de/bioroboticslab/robofish/io/-/issues/14
     """
 
-    f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25)
+    f = robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    )
     circle_rad = np.linspace(0, 2 * np.pi, num=100)
     circle_size = 40
     poses_rad = np.stack(
diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py
index 04d4b98..927f44f 100644
--- a/tests/robofish/io/test_file.py
+++ b/tests/robofish/io/test_file.py
@@ -15,33 +15,37 @@ nan_file_path = utils.full_path(__file__, "../../resources/nan_test.hdf5")
 
 
 def test_constructor():
-    sf = robofish.io.File(world_size_cm=[100, 100])
+    sf = robofish.io.File(world_size_cm=[100, 100], world_shape="rectangle")
     sf.validate()
     sf.close()
 
 
 def test_context():
-    with robofish.io.File(world_size_cm=[10, 10]) as f:
+    with robofish.io.File(world_size_cm=[10, 10], world_shape="rectangle") as f:
         pass
 
 
 def test_new_file_w_path(tmp_path):
     f = tmp_path / "file.hdf5"
-    sf = robofish.io.File(f, "w", world_size_cm=[100, 100], frequency_hz=25)
+    sf = robofish.io.File(
+        f, "w", world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    )
     sf.create_entity("fish")
     sf.validate()
     sf.close()
 
 
 def test_missing_attribute():
-    sf = robofish.io.File(world_size_cm=[10, 10])
+    sf = robofish.io.File(world_size_cm=[10, 10], world_shape="rectangle")
     sf.attrs.pop("world_size_cm")
     assert not sf.validate(strict_validate=False)[0]
     sf.close()
 
 
 def test_single_entity_frequency_hz():
-    sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=(1000 / 40))
+    sf = robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=(1000 / 40), world_shape="rectangle"
+    )
     test_poses = np.ones(shape=(10, 4))
     test_poses[:, 3] = 0  # All Fish pointing right
     sf.create_entity("robofish", poses=test_poses)
@@ -54,7 +58,9 @@ def test_single_entity_frequency_hz():
 def test_single_entity_monotonic_time_points_us():
     with pytest.warns(Warning):
         sf = robofish.io.File(
-            world_size_cm=[100, 100], monotonic_time_points_us=np.ones(10)
+            world_size_cm=[100, 100],
+            monotonic_time_points_us=np.ones(10),
+            world_shape="rectangle",
         )
     test_poses = np.ones(shape=(10, 4))
     test_poses[:, 3] = 0  # All Fish pointing right
@@ -76,7 +82,9 @@ def test_multiple_entities():
 
     with pytest.warns(Warning):
         sf = robofish.io.File(
-            world_size_cm=[100, 100], monotonic_time_points_us=m_points
+            world_size_cm=[100, 100],
+            monotonic_time_points_us=m_points,
+            world_shape="rectangle",
         )
     returned_entities = sf.create_multiple_entities("fish", poses)
     returned_names = [entity.name for entity in returned_entities]
@@ -146,7 +154,9 @@ def test_multiple_entities():
 
 
 def test_actions_speeds_turns_angles():
-    with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
+    with robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    ) as f:
         poses = np.zeros((10, 100, 3))
         f.create_multiple_entities("fish", poses=poses)
 
@@ -155,7 +165,9 @@ def test_actions_speeds_turns_angles():
 
 
 def test_entity_poses_rad(caplog):
-    with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
+    with robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    ) as f:
         # Create an entity, using radians
         f.create_entity("fish", poses=np.ones((100, 3)))
 
@@ -169,7 +181,9 @@ def test_entity_poses_rad(caplog):
 
 
 def test_entity_positions_no_orientation():
-    with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f:
+    with robofish.io.File(
+        world_size_cm=[100, 100], frequency_hz=25, world_shape="rectangle"
+    ) as f:
         # Create an entity, using radians
         f.create_entity("fish", positions=np.ones((100, 2)))
 
-- 
GitLab