diff --git a/.gitignore b/.gitignore index c93c2451ee6958af633a4e24fd66c34304682f73..043b66a95a35968ee555bd349ade5dddc7a2dc0c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ dist report.xml htmlcov docs +env !tests/resources/*.hdf5 *.ipynb_checkpoints diff --git a/README.md b/README.md index 8ccc8b9f7ffedbc46e0f8f56dea023eeed5dafa2..5da734e05530add0c7ef723c9990f3634dbcd57f 100644 --- a/README.md +++ b/README.md @@ -31,14 +31,14 @@ We show a simple example below. More examples can be found in ```examples/``` import robofish.io import numpy as np -filename = "example.hdf5" +# Create a new robofish io file f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) f.attrs["experiment_setup"] = "This is a simple example with made up data." -# Create a single robot with 30 timesteps -# positions are passed separately -# orientations are passed as with two columns -> orientation_x and orientation_y +# Create a new robot entity. Positions and orientations are passed +# separately in this example. Since the orientations have two columns, +# unit vectors are assumed (orientation_x, orientation_y) f.create_entity( category="robot", name="robot", @@ -46,7 +46,9 @@ f.create_entity( orientations=np.ones((100, 2)) * [0, 1], ) -# Create fishes with 30 poses (x, y, orientation_rad) +# Create a new fish entity. +# In this case, we pass positions and orientations together (x, y, rad). +# Since it is a 3 column array, orientations in radiants are assumed. poses = np.zeros((100, 3)) poses[:, 0] = np.arange(-50, 50) poses[:, 1] = np.arange(-50, 50) @@ -57,12 +59,9 @@ fish.attrs["fish_standard_length_cm"] = 10 # Show and save the file print(f) -print("Poses Shape: ", f.get_poses().shape) - -# Saving also validates the file -f.save(filename) -print(f"Saved to {filename}") +print("Poses Shape: ", f.entity_poses.shape) +f.save_as(path) ``` ### Evaluation diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index e737a4c023bb58873e9331b73d68a8af40ad8fda..eba81c56c7c497751bffce4c944b4ae3212dac15 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -2,7 +2,65 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import robofish.io\n", + "import numpy as np\n", + "\n", + "\n", + "def create_example_file(path):\n", + " # Create a new io file object with a 100x100cm world\n", + " f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)\n", + "\n", + " # create a simple obstacle, fixed in place, fixed outline\n", + " obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]]\n", + " obstacle_name = f.create_entity(\n", + " \"obstacle\", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline\n", + " )\n", + "\n", + " # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.\n", + " robofish_timesteps = 1000\n", + " robofish_poses = np.tile([50, 50, 1, 0], (robofish_timesteps, 1))\n", + " robot = f.create_entity(\"robot\", name=\"robot\", poses=robofish_poses)\n", + "\n", + " # create multiple fishes with timestamps. Since we don't specify names, but only the type \"fish\" the fishes will be named [\"fish_1\", \"fish_2\", \"fish_3\"]\n", + " agents = 3\n", + " timesteps = 1000\n", + " # timestamps = np.linspace(0, timesteps + 1, timesteps)\n", + " agent_poses = np.random.random((agents, timesteps, 3))\n", + "\n", + " fishes = f.create_multiple_entities(\"fish\", agent_poses)\n", + "\n", + " # This would throw an exception if the file was invalid\n", + " f.validate()\n", + "\n", + " # Save file validates aswell\n", + " f.save_as(path)\n", + "\n", + " # Closing and opening files (just for demonstration). When opening with r+, we can read and write afterwards.\n", + " f.close()\n", + " f = robofish.io.File(path, \"r+\")\n", + "\n", + " print(\"\\nEntity Names\")\n", + " print(f.entity_names)\n", + "\n", + " # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.\n", + " print(\"\\nAll poses\")\n", + " print(f.entity_poses)\n", + "\n", + " # Select all entities with the category fish\n", + " print(\"\\nFish poses\")\n", + " print(f.select_entity_poses(lambda e: e.category == \"fish\"))\n", + "\n", + " print(\"\\nFinal file\")\n", + " print(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -14,29 +72,29 @@ "['fish_1', 'fish_2', 'fish_3', 'obstacle_1', 'robot']\n", "\n", "All poses\n", - "[[[5.56598067e-01 2.42021829e-01 2.45265663e-01 2.95294344e-01]\n", - " [9.83367860e-01 4.36241180e-01 5.68063319e-01 8.96367550e-01]\n", - " [8.71030211e-01 9.32705551e-02 8.51928890e-01 9.48658109e-01]\n", + "[[[5.20669639e-01 4.58831601e-02 9.51813161e-01 3.06678474e-01]\n", + " [4.54089224e-01 4.62401301e-01 7.62699127e-01 6.46753490e-01]\n", + " [9.67400968e-01 5.63468635e-01 9.99894381e-01 1.45336846e-02]\n", " ...\n", - " [6.94550753e-01 6.52826130e-01 6.08456850e-01 4.88614887e-01]\n", - " [4.84500438e-01 1.14949882e-01 6.08556986e-01 7.93362781e-02]\n", - " [9.54909027e-01 3.18913072e-01 4.58294243e-01 7.45387852e-01]]\n", + " [8.80278572e-02 3.80126536e-01 9.97331321e-01 7.30081499e-02]\n", + " [3.28493923e-01 6.17436647e-01 9.20374036e-01 3.91039193e-01]\n", + " [7.53752515e-02 2.74342328e-01 7.84065127e-01 6.20678544e-01]]\n", "\n", - " [[7.51943365e-02 9.95502114e-01 5.04003823e-01 8.36720586e-01]\n", - " [8.18848014e-01 4.04324770e-01 5.49858093e-01 3.51742476e-01]\n", - " [1.66903093e-01 1.78061739e-01 2.81622916e-01 8.88221264e-01]\n", + " [[4.45818096e-01 5.39128423e-01 5.66779137e-01 8.23869765e-01]\n", + " [4.38290328e-01 3.78231764e-01 9.88384008e-01 1.51977092e-01]\n", + " [1.82629541e-01 8.08078349e-01 8.27510357e-01 5.61450422e-01]\n", " ...\n", - " [7.69020915e-01 6.33118331e-01 4.15713340e-01 4.24170971e-01]\n", - " [7.64205098e-01 7.78579533e-01 7.44598091e-01 3.30398619e-01]\n", - " [4.45223182e-01 9.25011218e-01 2.36187894e-02 7.62242600e-02]]\n", + " [5.48624545e-02 9.00771558e-01 8.00554574e-01 5.99259913e-01]\n", + " [8.34412515e-01 2.47002933e-02 8.98045361e-01 4.39902902e-01]\n", + " [2.35550269e-01 5.17610013e-01 7.89557755e-01 6.13676250e-01]]\n", "\n", - " [[7.73669600e-01 3.54849041e-01 4.21867281e-01 9.16552365e-01]\n", - " [4.73650604e-01 1.79305673e-01 8.38760436e-01 3.96051705e-01]\n", - " [2.01547332e-02 8.12301695e-01 2.78097481e-01 8.67732406e-01]\n", + " [[4.05882373e-02 9.17375386e-01 9.94780242e-01 1.02040365e-01]\n", + " [8.21412623e-01 6.68225110e-01 9.35957074e-01 3.52114081e-01]\n", + " [8.17060947e-01 2.38414750e-01 9.85274673e-01 1.70979008e-01]\n", " ...\n", - " [8.12627971e-01 6.28660858e-01 2.13196307e-01 6.49513781e-01]\n", - " [5.58035910e-01 4.63277161e-01 8.21570277e-01 6.79726541e-01]\n", - " [7.10023165e-01 5.45146585e-01 8.51007760e-01 9.56029415e-01]]\n", + " [1.20162942e-01 4.16668624e-01 9.17142212e-01 3.98560166e-01]\n", + " [9.19857025e-01 7.85686851e-01 6.58272266e-01 7.52779961e-01]\n", + " [8.69306684e-01 5.62790215e-01 9.19795334e-01 3.92398477e-01]]\n", "\n", " [[5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", " [ nan nan nan nan]\n", @@ -46,38 +104,38 @@ " [ nan nan nan nan]\n", " [ nan nan nan nan]]\n", "\n", - " [[5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", + " [[5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", " ...\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]]]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]]]\n", "\n", "Fish poses\n", - "[[[0.55659807 0.24202183 0.24526566 0.29529434]\n", - " [0.98336786 0.43624118 0.56806332 0.89636755]\n", - " [0.87103021 0.09327056 0.85192889 0.94865811]\n", + "[[[0.52066964 0.04588316 0.95181316 0.30667847]\n", + " [0.45408922 0.4624013 0.76269913 0.64675349]\n", + " [0.96740097 0.56346864 0.99989438 0.01453368]\n", " ...\n", - " [0.69455075 0.65282613 0.60845685 0.48861489]\n", - " [0.48450044 0.11494988 0.60855699 0.07933628]\n", - " [0.95490903 0.31891307 0.45829424 0.74538785]]\n", + " [0.08802786 0.38012654 0.99733132 0.07300815]\n", + " [0.32849392 0.61743665 0.92037404 0.39103919]\n", + " [0.07537525 0.27434233 0.78406513 0.62067854]]\n", "\n", - " [[0.07519434 0.99550211 0.50400382 0.83672059]\n", - " [0.81884801 0.40432477 0.54985809 0.35174248]\n", - " [0.16690309 0.17806174 0.28162292 0.88822126]\n", + " [[0.4458181 0.53912842 0.56677914 0.82386976]\n", + " [0.43829033 0.37823176 0.98838401 0.15197709]\n", + " [0.18262954 0.80807835 0.82751036 0.56145042]\n", " ...\n", - " [0.76902092 0.63311833 0.41571334 0.42417097]\n", - " [0.7642051 0.77857953 0.74459809 0.33039862]\n", - " [0.44522318 0.92501122 0.02361879 0.07622426]]\n", + " [0.05486245 0.90077156 0.80055457 0.59925991]\n", + " [0.83441252 0.02470029 0.89804536 0.4399029 ]\n", + " [0.23555027 0.51761001 0.78955775 0.61367625]]\n", "\n", - " [[0.7736696 0.35484904 0.42186728 0.91655236]\n", - " [0.4736506 0.17930567 0.83876044 0.3960517 ]\n", - " [0.02015473 0.8123017 0.27809748 0.86773241]\n", + " [[0.04058824 0.91737539 0.99478024 0.10204037]\n", + " [0.82141262 0.66822511 0.93595707 0.35211408]\n", + " [0.81706095 0.23841475 0.98527467 0.17097901]\n", " ...\n", - " [0.81262797 0.62866086 0.21319631 0.64951378]\n", - " [0.55803591 0.46327716 0.82157028 0.67972654]\n", - " [0.71002316 0.54514658 0.85100776 0.95602942]]]\n", + " [0.12016294 0.41666862 0.91714221 0.39856017]\n", + " [0.91985703 0.78568685 0.65827227 0.75277996]\n", + " [0.86930668 0.56279022 0.91979533 0.39239848]]]\n", "\n", "File structure\n", " format_url:\thttps://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0\n", @@ -114,61 +172,16 @@ } ], "source": [ - "import robofish.io\n", - "import numpy as np\n", - "\n", - "\n", - "def create_example_file(path):\n", - " # Create a new io file object with a 100x100cm world\n", - " sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)\n", - "\n", - " # create a simple obstacle, fixed in place, fixed outline\n", - " obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]]\n", - " obstacle_name = sf.create_entity(\n", - " \"obstacle\", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline\n", - " )\n", - "\n", - " # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.\n", - " robofish_timesteps = 1000\n", - " robofish_poses = np.ones((robofish_timesteps, 4)) * 50\n", - " robot = sf.create_entity(\"robot\", robofish_poses, name=\"robot\")\n", - "\n", - " # create multiple fishes with timestamps. Since we don't specify names, but only the type \"fish\" the fishes will be named [\"fish_1\", \"fish_2\", \"fish_3\"]\n", - " agents = 3\n", - " timesteps = 1000\n", - " # timestamps = np.linspace(0, timesteps + 1, timesteps)\n", - " agent_poses = np.random.random((agents, timesteps, 4))\n", - "\n", - " fishes = sf.create_multiple_entities(\"fish\", agent_poses)\n", - "\n", - " # This would throw an exception if the file was invalid\n", - " sf.validate()\n", - "\n", - " # Save file validates aswell\n", - "\n", - " sf.save_as(path)\n", - "\n", - " # Closing and opening files (just for demonstration)\n", - " sf.close()\n", - " sf = robofish.io.File(path=path)\n", - "\n", - " print(\"\\nEntity Names\")\n", - " print(sf.entity_names)\n", - "\n", - " # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.\n", - " print(\"\\nAll poses\")\n", - " print(sf.entity_poses)\n", - "\n", - " print(\"\\nFish poses\")\n", - " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", - "\n", - " print(\"\\nFile structure\")\n", - " print(sf)\n", - "\n", - "\n", "if __name__ == \"__main__\":\n", - " create_example_file(\"example.hdf5\")\n" + " create_example_file(\"example.hdf5\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/example_basic.py b/examples/example_basic.py index b9e356c7547a37cdfc66e75f8f93269bb4365a1c..34e092296f6edde6301a13623c15f77a46ddb70a 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -1,55 +1,53 @@ -#! /usr/bin/env python3 - import robofish.io import numpy as np def create_example_file(path): # Create a new io file object with a 100x100cm world - sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) + f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) # create a simple obstacle, fixed in place, fixed outline obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]] - obstacle_name = sf.create_entity( + obstacle_name = f.create_entity( "obstacle", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline ) # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1. robofish_timesteps = 1000 - robofish_poses = np.ones((robofish_timesteps, 4)) * 50 - robot = sf.create_entity("robot", robofish_poses, name="robot") + robofish_poses = np.tile([50, 50, 1, 0], (robofish_timesteps, 1)) + robot = f.create_entity("robot", name="robot", poses=robofish_poses) # create multiple fishes with timestamps. Since we don't specify names, but only the type "fish" the fishes will be named ["fish_1", "fish_2", "fish_3"] agents = 3 timesteps = 1000 # timestamps = np.linspace(0, timesteps + 1, timesteps) - agent_poses = np.random.random((agents, timesteps, 4)) + agent_poses = np.random.random((agents, timesteps, 3)) - fishes = sf.create_multiple_entities("fish", agent_poses) + fishes = f.create_multiple_entities("fish", agent_poses) # This would throw an exception if the file was invalid - sf.validate() + f.validate() # Save file validates aswell + f.save_as(path) - sf.save_as(path) - - # Closing and opening files (just for demonstration) - sf.close() - sf = robofish.io.File(path=path) + # Closing and opening files (just for demonstration). When opening with r+, we can read and write afterwards. + f.close() + f = robofish.io.File(path, "r+") print("\nEntity Names") - print(sf.entity_names) + print(f.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") - print(sf.entity_poses) + print(f.entity_poses) + # Select all entities with the category fish print("\nFish poses") - print(sf.select_entity_poses(lambda e: e.category == "fish")) + print(f.select_entity_poses(lambda e: e.category == "fish")) - print("\nFile structure") - print(sf) + print("\nFinal file") + print(f) if __name__ == "__main__": diff --git a/examples/example_readme.py b/examples/example_readme.py index f489958f1fd994dce6c19c46cd56bb7aab1b16f2..62df2336f89281a0d9306428839e9aee2921ae4b 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -1,35 +1,38 @@ import robofish.io import numpy as np -from pathlib import Path def create_example_file(path): - # By using the context, the file will be automatically validated - with robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0) as f: - f.attrs["experiment_setup"] = "This is a simple example with made up data." - - # Create a single robot with 30 timesteps - # positions are passed separately - # orientations are passed as with two columns -> orientation_x and orientation_y - f.create_entity( - category="robot", - name="robot", - positions=np.zeros((100, 2)), - orientations=np.ones((100, 2)) * [0, 1], - ) - - # Create fishes with 30 poses (x, y, orientation_rad) - poses = np.zeros((100, 3)) - poses[:, 0] = np.arange(-50, 50) - poses[:, 1] = np.arange(-50, 50) - poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) - fish = f.create_entity("fish", poses=poses) - fish.attrs["species"] = "My rotating spaghetti fish" - fish.attrs["fish_standard_length_cm"] = 10 - - # Show and save the file - print(f) - print("Poses Shape: ", f.entity_poses.shape) + # Create a new robofish io file + f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) + f.attrs["experiment_setup"] = "This is a simple example with made up data." + + # Create a new robot entity. Positions and orientations are passed + # separately in this example. Since the orientations have two columns, + # unit vectors are assumed (orientation_x, orientation_y) + f.create_entity( + category="robot", + name="robot", + positions=np.zeros((100, 2)), + orientations=np.ones((100, 2)) * [0, 1], + ) + + # Create a new fish entity. + # In this case, we pass positions and orientations together (x, y, rad). + # Since it is a 3 column array, orientations in radiants are assumed. + poses = np.zeros((100, 3)) + poses[:, 0] = np.arange(-50, 50) + poses[:, 1] = np.arange(-50, 50) + poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) + fish = f.create_entity("fish", poses=poses) + fish.attrs["species"] = "My rotating spaghetti fish" + fish.attrs["fish_standard_length_cm"] = 10 + + # Show and save the file + print(f) + print("Poses Shape: ", f.entity_poses.shape) + + f.save_as(path) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 3a6dad4b1f1245612f6d726bcf62240d28f5d61e..552f54e96c461e588672a3238c82b98cee1b8030 100644 --- a/setup.py +++ b/setup.py @@ -48,7 +48,7 @@ setup( version=source_version(), author="", author_email="", - install_requires=["h5py>=3", "numpy", "seaborn", "pandas", "deprecation"], + install_requires=["h5py>=2.10.0", "numpy", "seaborn", "pandas", "deprecation"], classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index f791a125e2e7ef9f22c539e9b9810061e444de88..669ac1746ab15dd5c24e6eae1b6986bf97c6a49b 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -30,6 +30,12 @@ class Entity(h5py.Group): poses, positions, orientations, outlines ) + assert poses is None or (poses.ndim == 2 and poses.shape[1] in [3, 4]) + assert positions is None or (positions.ndim == 2 and positions.shape[1] == 2) + assert orientations is None or ( + orientations.ndim == 2 and orientations.shape[1] in [1, 2] + ) + # If no name is given, create one from type and an id if name is None: i = 1 @@ -53,6 +59,10 @@ class Entity(h5py.Group): @classmethod def convert_rad_to_vector(cla, orientations_rad): + if min(orientations_rad) < 0 or max(orientations_rad) > 2 * np.pi: + logging.warning( + "Converting orientations, from a bigger range than [0, 2 * pi]. When passing the orientations, they are assumed to be in radiants." + ) ori_rad = utils.np_array(orientations_rad) assert ori_rad.shape[1] == 1 ori_vec = np.empty((ori_rad.shape[0], 2)) @@ -107,9 +117,10 @@ class Entity(h5py.Group): positions = self.create_dataset( "positions", data=positions, dtype=np.float32 ) - orientations = self.create_dataset( - "orientations", data=orientations, dtype=np.float32 - ) + if orientations is not None: + orientations = self.create_dataset( + "orientations", data=orientations, dtype=np.float32 + ) if sampling is not None: positions.attrs["sampling"] = sampling @@ -121,6 +132,9 @@ class Entity(h5py.Group): @property def orientations(self): + if not "orientations" in self: + # If no orientation is given, the default direction is to the right + return np.tile([1, 0], (self.positions.shape[0], 1)) return self["orientations"] @property diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 669c05c0bc3e019b099e4357e69ae07534318862..e4798c809fc28fe367e7b8198065622dcdbc341a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -208,14 +208,19 @@ class File(h5py.File): ) if default: - self.default_sampling = name - self["samplings"].attrs["default"] = self.default_sampling + self["samplings"].attrs["default"] = name return name @property def world_size(self): return self.attrs["world_size_cm"] + @property + def default_sampling(self): + if "default" in self["samplings"].attrs: + return self["samplings"].attrs["default"] + return None + @property def frequency(self): # NOTE: Only works if default sampling availabe and specified with frequency_hz. @@ -245,11 +250,10 @@ class File(h5py.File): Name of the created entity """ - if sampling is None: - if not hasattr(self, "default_sampling"): - raise Exception( - "There was no sampling specified, when creating the file, nor when creating the entity." - ) + if sampling is None and self.default_sampling is None: + raise Exception( + "There was no sampling specified, when creating the file, nor when creating the entity." + ) entity = robofish.io.Entity.create_entity( self["entities"], @@ -285,18 +289,18 @@ class File(h5py.File): """ assert poses.ndim == 3 - assert poses.shape[2] == 4 - n = poses.shape[0] + assert poses.shape[2] in [3, 4] + agents = poses.shape[0] timesteps = poses.shape[1] - returned_names = [] + entity_names = [] - for i in range(n): + for i in range(agents): e_name = None if names is None else names[i] e_outline = ( outlines if outlines is None or outlines.ndim == 3 else outlines[i] ) - returned_names.append( + entity_names.append( self.create_entity( category=category, sampling=sampling, @@ -305,7 +309,7 @@ class File(h5py.File): outlines=e_outline, ) ) - return returned_names + return entity_names @property def entity_names(self) -> Iterable[str]: @@ -327,8 +331,13 @@ class File(h5py.File): def entity_poses(self): return self.select_entity_poses(None) - def select_entity_poses(self, predicate=None) -> Iterable: - """ Select an array of the poses of entities + @property + def entity_poses_rad(self): + return self.select_entity_poses(None, rad=True) + + def select_entity_poses(self, predicate=None, rad=False) -> Iterable: + """ TODO: Rework + Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -346,7 +355,8 @@ class File(h5py.File): max_timesteps = max([0] + [e.positions.shape[0] for e in entities]) # Initialize poses output array - poses_output = np.empty((len(entities), max_timesteps, 4)) + pose_len = 3 if rad else 4 + poses_output = np.empty((len(entities), max_timesteps, pose_len)) poses_output[:] = np.nan # Fill poses output array @@ -361,7 +371,7 @@ class File(h5py.File): raise Exception( "Multiple samplings found, preventing return of a single array." ) - poses = entity.poses + poses = entity.poses_rad if rad else entity.poses poses_output[i][: poses.shape[0]] = poses i += 1 return poses_output diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 0718d2d95b5685baa46147f3f390179c0b8a16f5..a1bde527f16a2b28431c6b1148a4ecfec50d2354 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -200,7 +200,9 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): if positions.shape[0] > 0: # validate range of poses - validate_poses_range(iofile, positions, e_name) + validate_positions_range( + iofile.attrs["world_size_cm"], positions, e_name + ) if "orientations" in entity: assert_validate( @@ -228,6 +230,8 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): e_name, ) + validate_orientations_length(orientations, e_name) + # outlines if "outlines" in entity: outlines = entity["outlines"] @@ -291,37 +295,48 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): return (True, "") -def validate_poses_range(iofile, poses, e_name): - # Poses which are just a bit over the world edge are fine +def validate_positions_range(world_size, positions, e_name): + # positions which are just a bit over the world edge are fine error_allowance = 1.01 allowed_x = [ - -1 * iofile.attrs["world_size_cm"][0] * error_allowance / 2, - iofile.attrs["world_size_cm"][0] * error_allowance / 2, + -1 * world_size[0] * error_allowance / 2, + world_size[0] * error_allowance / 2, ] - real_x = [poses[:, 0].min(), poses[:, 0].max()] + real_x = [positions[:, 0].min(), positions[:, 0].max()] allowed_y = [ - -1 * iofile.attrs["world_size_cm"][1] * error_allowance / 2.0, - iofile.attrs["world_size_cm"][1] * error_allowance / 2.0, + -1 * world_size[1] * error_allowance / 2.0, + world_size[1] * error_allowance / 2.0, ] - real_y = [poses[:, 1].min(), poses[:, 1].max()] + real_y = [positions[:, 1].min(), positions[:, 1].max()] assert_validate( allowed_x[0] <= real_x[0] and real_x[1] <= allowed_x[1], - "Poses of x axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the poses" + "Positions of x axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions" % (allowed_x[0], allowed_x[1], real_x[0], real_x[1]), e_name, ) assert_validate( allowed_y[0] <= real_y[0] and real_y[1] <= allowed_y[1], - "Poses of y axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the poses" + "Positions of y axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions" % (allowed_y[0], allowed_y[1], real_y[0], real_y[1]), e_name, ) +def validate_orientations_length(orientations, e_name): + ori_lengths = np.linalg.norm(orientations, axis=1) + + assert_validate( + np.isclose(ori_lengths, 1).all(), + "The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1" + % (min(ori_lengths), max(ori_lengths)), + e_name, + ) + + def validate_iso8601(str_val: str) -> bool: """This function validates strings to match the ISO8601 format. diff --git a/tests/robofish/io/test_examples.py b/tests/robofish/io/test_examples.py index fde15a4f0f6b50e318de30281b75c8fd47b1550b..c81b1eed99221e9d92a82f0d354352fa869cfbae 100644 --- a/tests/robofish/io/test_examples.py +++ b/tests/robofish/io/test_examples.py @@ -32,9 +32,14 @@ def manual_test_example_basic_ipynb(): # Executing the notebook should not lead to an exception with testbook(str(ipynb_path), execute=True) as tb: pass + # tb.ref("create_example_file")(path) + # path.unlink() if __name__ == "__main__": + print("example_readme.py") test_example_readme() + print("example_basic.py") test_example_basic() + print("example_basic.ipynb") manual_test_example_basic_ipynb() diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index e19031ea8008470df9560a2e685b8716826c6e11..31c54bc8382f82aa6376af26f08fcafe257f619a 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -8,6 +8,8 @@ import datetime import sys import logging +LOGGER = logging.getLogger(__name__) + valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5") created_by_test_path = utils.full_path(__file__, "../../resources/created_by_test.hdf5") created_by_test_path_2 = utils.full_path( @@ -42,6 +44,7 @@ def test_missing_attribute(): def test_single_entity_monotonic_step(): sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=(1000 / 40)) test_poses = np.ones(shape=(10, 4)) + test_poses[:, 3] = 0 # All Fish pointing right sf.create_entity("robofish", poses=test_poses) sf.create_entity("object_without_poses") print(sf) @@ -53,6 +56,7 @@ def test_single_entity_monotonic_time_points_us(): world_size_cm=[100, 100], monotonic_time_points_us=np.ones(10) ) test_poses = np.ones(shape=(10, 4)) + test_poses[:, 3] = 0 # All Fish pointing right sf.create_entity("robofish", poses=test_poses) print(sf) sf.validate() @@ -64,7 +68,7 @@ def test_multiple_entities(): poses = np.zeros((agents, timesteps, 4)) poses[1] = 1 - poses[2] = 2 + poses[:, :, 2:] = [1, 0] # All agents point to the right m_points = np.arange(timesteps) @@ -95,9 +99,9 @@ def test_multiple_entities(): assert (returned_poses == poses[:1]).all() # Insert some random obstacles - returned_names = sf.create_multiple_entities( - "obstacle", poses=np.random.random((agents, timesteps, 4)) - ) + obstacles = 3 + obs_poses = np.random.random((agents, 1, 3)) + returned_names = sf.create_multiple_entities("obstacle", poses=obs_poses) # Obstacles should not be returned when only fish are selected returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() @@ -115,12 +119,6 @@ def test_multiple_entities(): monotonic_time_points_us=m_points, calendar_time_points=c_points ) - try: - sf.create_sampling(frequency_hz=25, monotonic_time_points_us=m_points) - raise Exception("This sampling should have created an error") - except: - pass - returned_names = sf.create_multiple_entities( "fish", poses, outlines=outlines, sampling=new_sampling ) @@ -129,6 +127,9 @@ def test_multiple_entities(): # pass an poses array in separate parts (positions, orientations) and retrieve it with poses. poses_arr = np.random.random((100, 4)) + poses_arr[:, 2:] /= np.atleast_2d( + np.linalg.norm(poses_arr[:, 2:], axis=1) + ).T # Normalization position_orientation_fish = sf.create_entity( "fish", positions=poses_arr[:, :2], orientations=poses_arr[:, 2:] ) @@ -138,6 +139,15 @@ def test_multiple_entities(): return sf +def test_broken_sampling(caplog): + sf = robofish.io.File(world_size_cm=[10, 10]) + caplog.set_level(logging.ERROR) + broken_sampling = sf.create_sampling( + name="broken sampling", frequency_hz=25, monotonic_time_points_us=np.ones((100)) + ) + assert "ERROR" in caplog.text + + def test_deprecated_get_poses(): f = test_multiple_entities() with pytest.warns(DeprecationWarning): @@ -146,6 +156,29 @@ def test_deprecated_get_poses(): assert f.get_poses(names="fish_1").shape[0] == 1 +def test_entity_poses_rad(caplog): + with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f: + # Create an entity, using radiants + f.create_entity("fish", poses=np.ones((100, 3))) + + # Read the poses of the file as radiants + np.testing.assert_almost_equal(f.entity_poses_rad, np.ones((1, 100, 3))) + + caplog.set_level(logging.WARNING) + # Passing orientations, which are bigger than 2 pi + f.create_entity("fish", poses=np.ones((100, 3)) * 50) + assert "WARNING" in caplog.text + + +def test_entity_positions_no_orientation(): + with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f: + # Create an entity, using radiants + f.create_entity("fish", positions=np.ones((100, 2))) + + assert f.entity_poses.shape == (1, 100, 4) + assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all() + + def test_load_validate(): sf = robofish.io.File(path=valid_file_path) sf.validate() @@ -173,6 +206,12 @@ def test_loading_saving(): # After saving, the file should still be accessible and valid sf.validate() + # Open the file again and add another entity + sf = robofish.io.File(created_by_test_path, "r+") + entity = sf.create_entity("fish", positions=np.ones((100, 2))) + sf.entity_poses + entity.poses + def test_validate_created_file_after_reloading(): sf = robofish.io.File(created_by_test_path) @@ -182,8 +221,9 @@ def test_validate_created_file_after_reloading(): # Cleanup test. The z in the name makes sure, that it is executed last in main def test_z_cleanup(): """ This cleans up after all tests and removes all test artifacts """ - created_by_test_path.unlink() - created_by_test_path_2.unlink() + for f in [created_by_test_path, created_by_test_path_2]: + if f.exists(): + f.unlink() if __name__ == "__main__":