diff --git a/.gitignore b/.gitignore index c93c2451ee6958af633a4e24fd66c34304682f73..043b66a95a35968ee555bd349ade5dddc7a2dc0c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ dist report.xml htmlcov docs +env !tests/resources/*.hdf5 *.ipynb_checkpoints diff --git a/README.md b/README.md index 677e3a4d077fc09e1c21b3eb3b01166e52e41268..5da734e05530add0c7ef723c9990f3634dbcd57f 100644 --- a/README.md +++ b/README.md @@ -32,33 +32,36 @@ import robofish.io import numpy as np -# By using the context, the file will be automatically validated -with robofish.io.File("Example.hdf5", "w", world_size_cm=[100, 100], frequency_hz=25.0) as f: - f.attrs["experiment_setup"] = "This is a simple example with made up data." - - # Create a single robot with 30 timesteps - # positions are passed separately - # orientations are passed as with two columns -> orientation_x and orientation_y - f.create_entity( - category="robot", - name="robot", - positions=np.zeros((100, 2)), - orientations=np.ones((100, 2)) * [0, 1], - ) - - # Create fishes with 30 poses (x, y, orientation_rad) - poses = np.zeros((100, 3)) - poses[:, 0] = np.arange(-50, 50) - poses[:, 1] = np.arange(-50, 50) - poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) - fish = f.create_entity("fish", poses=poses) - fish.attrs["species"] = "My rotating spaghetti fish" - fish.attrs["fish_standard_length_cm"] = 10 - - # Show and save the file - print(f) - print("Poses Shape: ", f.entity_poses.shape) - +# Create a new robofish io file +f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) +f.attrs["experiment_setup"] = "This is a simple example with made up data." + +# Create a new robot entity. Positions and orientations are passed +# separately in this example. Since the orientations have two columns, +# unit vectors are assumed (orientation_x, orientation_y) +f.create_entity( + category="robot", + name="robot", + positions=np.zeros((100, 2)), + orientations=np.ones((100, 2)) * [0, 1], +) + +# Create a new fish entity. +# In this case, we pass positions and orientations together (x, y, rad). +# Since it is a 3 column array, orientations in radiants are assumed. +poses = np.zeros((100, 3)) +poses[:, 0] = np.arange(-50, 50) +poses[:, 1] = np.arange(-50, 50) +poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) +fish = f.create_entity("fish", poses=poses) +fish.attrs["species"] = "My rotating spaghetti fish" +fish.attrs["fish_standard_length_cm"] = 10 + +# Show and save the file +print(f) +print("Poses Shape: ", f.entity_poses.shape) + +f.save_as(path) ``` ### Evaluation @@ -68,7 +71,7 @@ Current modes are: - turn - tank_positions - trajectories -- follow_ii +- follow_iid ## LICENSE diff --git a/ci/test.py b/ci/test.py index c6d90a39d599fc668a6a1ed525c2a940ddeeec4f..ab41980d6b6bae18e8bf990f162b20b57edd82af 100755 --- a/ci/test.py +++ b/ci/test.py @@ -34,7 +34,7 @@ if __name__ == "__main__": "--prompt", "ci", ".venv", - ], + ] ) check_call( @@ -44,9 +44,7 @@ if __name__ == "__main__": "pip", "install", str(sorted(Path("dist").glob("*.whl"))[-1].resolve()), - ], + ] ) - check_call( - [python_venv_executable(), "-m", "pytest", "--junitxml=report.xml"], - ) + check_call([python_venv_executable(), "-m", "pytest", "--junitxml=report.xml"]) diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index 3569083dc85945861dea2aecdba701dbfe2d3327..eba81c56c7c497751bffce4c944b4ae3212dac15 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -2,7 +2,65 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "import robofish.io\n", + "import numpy as np\n", + "\n", + "\n", + "def create_example_file(path):\n", + " # Create a new io file object with a 100x100cm world\n", + " f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)\n", + "\n", + " # create a simple obstacle, fixed in place, fixed outline\n", + " obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]]\n", + " obstacle_name = f.create_entity(\n", + " \"obstacle\", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline\n", + " )\n", + "\n", + " # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.\n", + " robofish_timesteps = 1000\n", + " robofish_poses = np.tile([50, 50, 1, 0], (robofish_timesteps, 1))\n", + " robot = f.create_entity(\"robot\", name=\"robot\", poses=robofish_poses)\n", + "\n", + " # create multiple fishes with timestamps. Since we don't specify names, but only the type \"fish\" the fishes will be named [\"fish_1\", \"fish_2\", \"fish_3\"]\n", + " agents = 3\n", + " timesteps = 1000\n", + " # timestamps = np.linspace(0, timesteps + 1, timesteps)\n", + " agent_poses = np.random.random((agents, timesteps, 3))\n", + "\n", + " fishes = f.create_multiple_entities(\"fish\", agent_poses)\n", + "\n", + " # This would throw an exception if the file was invalid\n", + " f.validate()\n", + "\n", + " # Save file validates aswell\n", + " f.save_as(path)\n", + "\n", + " # Closing and opening files (just for demonstration). When opening with r+, we can read and write afterwards.\n", + " f.close()\n", + " f = robofish.io.File(path, \"r+\")\n", + "\n", + " print(\"\\nEntity Names\")\n", + " print(f.entity_names)\n", + "\n", + " # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.\n", + " print(\"\\nAll poses\")\n", + " print(f.entity_poses)\n", + "\n", + " # Select all entities with the category fish\n", + " print(\"\\nFish poses\")\n", + " print(f.select_entity_poses(lambda e: e.category == \"fish\"))\n", + "\n", + " print(\"\\nFinal file\")\n", + " print(f)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, "metadata": {}, "outputs": [ { @@ -14,29 +72,29 @@ "['fish_1', 'fish_2', 'fish_3', 'obstacle_1', 'robot']\n", "\n", "All poses\n", - "[[[1.36782631e-01 4.68791187e-01 2.42182687e-01 6.48771763e-01]\n", - " [7.93363988e-01 3.50356251e-01 3.64646018e-01 1.12773582e-01]\n", - " [2.60495663e-01 8.65413904e-01 2.21576691e-02 6.10984743e-01]\n", + "[[[5.20669639e-01 4.58831601e-02 9.51813161e-01 3.06678474e-01]\n", + " [4.54089224e-01 4.62401301e-01 7.62699127e-01 6.46753490e-01]\n", + " [9.67400968e-01 5.63468635e-01 9.99894381e-01 1.45336846e-02]\n", " ...\n", - " [8.34640980e-01 9.05928910e-01 7.33852565e-01 9.64702129e-01]\n", - " [5.20252772e-02 1.44064710e-01 9.47307721e-02 3.03431451e-01]\n", - " [1.99740291e-01 3.40426266e-02 5.65072261e-02 4.65648413e-01]]\n", + " [8.80278572e-02 3.80126536e-01 9.97331321e-01 7.30081499e-02]\n", + " [3.28493923e-01 6.17436647e-01 9.20374036e-01 3.91039193e-01]\n", + " [7.53752515e-02 2.74342328e-01 7.84065127e-01 6.20678544e-01]]\n", "\n", - " [[2.59539992e-01 3.93728465e-01 9.71068442e-01 6.31185651e-01]\n", - " [1.83127314e-01 1.43956905e-02 4.26256537e-01 5.80079734e-01]\n", - " [3.76704991e-01 6.72040820e-01 5.88318594e-02 3.48662198e-01]\n", + " [[4.45818096e-01 5.39128423e-01 5.66779137e-01 8.23869765e-01]\n", + " [4.38290328e-01 3.78231764e-01 9.88384008e-01 1.51977092e-01]\n", + " [1.82629541e-01 8.08078349e-01 8.27510357e-01 5.61450422e-01]\n", " ...\n", - " [3.48356485e-01 9.14380550e-01 4.48769242e-01 7.38050520e-01]\n", - " [6.12327099e-01 1.99107945e-01 8.90928864e-01 2.84884181e-02]\n", - " [8.27723503e-01 6.57829344e-01 4.65144426e-01 4.01587933e-01]]\n", + " [5.48624545e-02 9.00771558e-01 8.00554574e-01 5.99259913e-01]\n", + " [8.34412515e-01 2.47002933e-02 8.98045361e-01 4.39902902e-01]\n", + " [2.35550269e-01 5.17610013e-01 7.89557755e-01 6.13676250e-01]]\n", "\n", - " [[8.43813717e-01 3.53193313e-01 3.37166399e-01 4.36319530e-01]\n", - " [9.16836739e-01 2.00342074e-01 1.78178921e-01 6.41124010e-01]\n", - " [4.94625986e-01 2.56477743e-01 3.52547020e-01 7.87709892e-01]\n", + " [[4.05882373e-02 9.17375386e-01 9.94780242e-01 1.02040365e-01]\n", + " [8.21412623e-01 6.68225110e-01 9.35957074e-01 3.52114081e-01]\n", + " [8.17060947e-01 2.38414750e-01 9.85274673e-01 1.70979008e-01]\n", " ...\n", - " [5.02124608e-01 6.97787464e-01 2.58715957e-01 2.88078666e-01]\n", - " [9.08326209e-01 2.46579379e-01 1.59346357e-01 6.96394265e-01]\n", - " [3.52213830e-01 6.03601038e-02 2.54920900e-01 3.10269982e-01]]\n", + " [1.20162942e-01 4.16668624e-01 9.17142212e-01 3.98560166e-01]\n", + " [9.19857025e-01 7.85686851e-01 6.58272266e-01 7.52779961e-01]\n", + " [8.69306684e-01 5.62790215e-01 9.19795334e-01 3.92398477e-01]]\n", "\n", " [[5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", " [ nan nan nan nan]\n", @@ -46,38 +104,38 @@ " [ nan nan nan nan]\n", " [ nan nan nan nan]]\n", "\n", - " [[5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", + " [[5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", " ...\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]\n", - " [5.00000000e+01 5.00000000e+01 5.00000000e+01 5.00000000e+01]]]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]\n", + " [5.00000000e+01 5.00000000e+01 1.00000000e+00 0.00000000e+00]]]\n", "\n", "Fish poses\n", - "[[[0.13678263 0.46879119 0.24218269 0.64877176]\n", - " [0.79336399 0.35035625 0.36464602 0.11277358]\n", - " [0.26049566 0.8654139 0.02215767 0.61098474]\n", + "[[[0.52066964 0.04588316 0.95181316 0.30667847]\n", + " [0.45408922 0.4624013 0.76269913 0.64675349]\n", + " [0.96740097 0.56346864 0.99989438 0.01453368]\n", " ...\n", - " [0.83464098 0.90592891 0.73385257 0.96470213]\n", - " [0.05202528 0.14406471 0.09473077 0.30343145]\n", - " [0.19974029 0.03404263 0.05650723 0.46564841]]\n", + " [0.08802786 0.38012654 0.99733132 0.07300815]\n", + " [0.32849392 0.61743665 0.92037404 0.39103919]\n", + " [0.07537525 0.27434233 0.78406513 0.62067854]]\n", "\n", - " [[0.25953999 0.39372846 0.97106844 0.63118565]\n", - " [0.18312731 0.01439569 0.42625654 0.58007973]\n", - " [0.37670499 0.67204082 0.05883186 0.3486622 ]\n", + " [[0.4458181 0.53912842 0.56677914 0.82386976]\n", + " [0.43829033 0.37823176 0.98838401 0.15197709]\n", + " [0.18262954 0.80807835 0.82751036 0.56145042]\n", " ...\n", - " [0.34835649 0.91438055 0.44876924 0.73805052]\n", - " [0.6123271 0.19910794 0.89092886 0.02848842]\n", - " [0.8277235 0.65782934 0.46514443 0.40158793]]\n", + " [0.05486245 0.90077156 0.80055457 0.59925991]\n", + " [0.83441252 0.02470029 0.89804536 0.4399029 ]\n", + " [0.23555027 0.51761001 0.78955775 0.61367625]]\n", "\n", - " [[0.84381372 0.35319331 0.3371664 0.43631953]\n", - " [0.91683674 0.20034207 0.17817892 0.64112401]\n", - " [0.49462599 0.25647774 0.35254702 0.78770989]\n", + " [[0.04058824 0.91737539 0.99478024 0.10204037]\n", + " [0.82141262 0.66822511 0.93595707 0.35211408]\n", + " [0.81706095 0.23841475 0.98527467 0.17097901]\n", " ...\n", - " [0.50212461 0.69778746 0.25871596 0.28807867]\n", - " [0.90832621 0.24657938 0.15934636 0.69639426]\n", - " [0.35221383 0.0603601 0.2549209 0.31026998]]]\n", + " [0.12016294 0.41666862 0.91714221 0.39856017]\n", + " [0.91985703 0.78568685 0.65827227 0.75277996]\n", + " [0.86930668 0.56279022 0.91979533 0.39239848]]]\n", "\n", "File structure\n", " format_url:\thttps://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0\n", @@ -114,64 +172,16 @@ } ], "source": [ - "#! /usr/bin/env python3\n", - "\n", - "import robofish.io\n", - "from robofish.io import utils\n", - "import numpy as np\n", - "\n", - "\n", - "def create_example_file(path):\n", - " # Create a new io file object with a 100x100cm world\n", - " sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)\n", - "\n", - " # create a simple obstacle, fixed in place, fixed outline\n", - " obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]]\n", - " obstacle_name = sf.create_entity(\n", - " \"obstacle\", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline\n", - " )\n", - "\n", - " # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1.\n", - " robofish_timesteps = 1000\n", - " robofish_poses = np.ones((robofish_timesteps, 4)) * 50\n", - " robot = sf.create_entity(\"robot\", robofish_poses, name=\"robot\")\n", - "\n", - " # create multiple fishes with timestamps. Since we don't specify names, but only the type \"fish\" the fishes will be named [\"fish_1\", \"fish_2\", \"fish_3\"]\n", - " agents = 3\n", - " timesteps = 1000\n", - " # timestamps = np.linspace(0, timesteps + 1, timesteps)\n", - " agent_poses = np.random.random((agents, timesteps, 4))\n", - "\n", - " fishes = sf.create_multiple_entities(\"fish\", agent_poses)\n", - "\n", - " # This would throw an exception if the file was invalid\n", - " sf.validate()\n", - "\n", - " # Save file validates aswell\n", - "\n", - " sf.save_as(path)\n", - "\n", - " # Closing and opening files (just for demonstration)\n", - " sf.close()\n", - " sf = robofish.io.File(path=path)\n", - "\n", - " print(\"\\nEntity Names\")\n", - " print(sf.entity_names)\n", - "\n", - " # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans.\n", - " print(\"\\nAll poses\")\n", - " print(sf.entity_poses)\n", - "\n", - " print(\"\\nFish poses\")\n", - " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", - "\n", - " print(\"\\nFile structure\")\n", - " print(sf)\n", - "\n", - "\n", "if __name__ == \"__main__\":\n", - " create_example_file(\"example.hdf5\")\n" + " create_example_file(\"example.hdf5\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { diff --git a/examples/example_basic.py b/examples/example_basic.py index dec8e3be1c7997b33b36f9868f0df8b0c6c3fbe6..34e092296f6edde6301a13623c15f77a46ddb70a 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -1,56 +1,53 @@ -#! /usr/bin/env python3 - import robofish.io -from robofish.io import utils import numpy as np def create_example_file(path): # Create a new io file object with a 100x100cm world - sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) + f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) # create a simple obstacle, fixed in place, fixed outline obstacle_outline = [[[-10, -10], [-10, 0], [0, 0], [0, -10]]] - obstacle_name = sf.create_entity( + obstacle_name = f.create_entity( "obstacle", positions=[[50, 50]], orientations=[[0]], outlines=obstacle_outline ) # create a robofish with 1000 timesteps. If we would not give a name, the name would be generated to be robot_1. robofish_timesteps = 1000 - robofish_poses = np.ones((robofish_timesteps, 4)) * 50 - robot = sf.create_entity("robot", robofish_poses, name="robot") + robofish_poses = np.tile([50, 50, 1, 0], (robofish_timesteps, 1)) + robot = f.create_entity("robot", name="robot", poses=robofish_poses) # create multiple fishes with timestamps. Since we don't specify names, but only the type "fish" the fishes will be named ["fish_1", "fish_2", "fish_3"] agents = 3 timesteps = 1000 # timestamps = np.linspace(0, timesteps + 1, timesteps) - agent_poses = np.random.random((agents, timesteps, 4)) + agent_poses = np.random.random((agents, timesteps, 3)) - fishes = sf.create_multiple_entities("fish", agent_poses) + fishes = f.create_multiple_entities("fish", agent_poses) # This would throw an exception if the file was invalid - sf.validate() + f.validate() # Save file validates aswell + f.save_as(path) - sf.save_as(path) - - # Closing and opening files (just for demonstration) - sf.close() - sf = robofish.io.File(path=path) + # Closing and opening files (just for demonstration). When opening with r+, we can read and write afterwards. + f.close() + f = robofish.io.File(path, "r+") print("\nEntity Names") - print(sf.entity_names) + print(f.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") - print(sf.entity_poses) + print(f.entity_poses) + # Select all entities with the category fish print("\nFish poses") - print(sf.select_entity_poses(lambda e: e.category == "fish")) + print(f.select_entity_poses(lambda e: e.category == "fish")) - print("\nFile structure") - print(sf) + print("\nFinal file") + print(f) if __name__ == "__main__": diff --git a/examples/example_readme.py b/examples/example_readme.py index f489958f1fd994dce6c19c46cd56bb7aab1b16f2..62df2336f89281a0d9306428839e9aee2921ae4b 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -1,35 +1,38 @@ import robofish.io import numpy as np -from pathlib import Path def create_example_file(path): - # By using the context, the file will be automatically validated - with robofish.io.File(path, "w", world_size_cm=[100, 100], frequency_hz=25.0) as f: - f.attrs["experiment_setup"] = "This is a simple example with made up data." - - # Create a single robot with 30 timesteps - # positions are passed separately - # orientations are passed as with two columns -> orientation_x and orientation_y - f.create_entity( - category="robot", - name="robot", - positions=np.zeros((100, 2)), - orientations=np.ones((100, 2)) * [0, 1], - ) - - # Create fishes with 30 poses (x, y, orientation_rad) - poses = np.zeros((100, 3)) - poses[:, 0] = np.arange(-50, 50) - poses[:, 1] = np.arange(-50, 50) - poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) - fish = f.create_entity("fish", poses=poses) - fish.attrs["species"] = "My rotating spaghetti fish" - fish.attrs["fish_standard_length_cm"] = 10 - - # Show and save the file - print(f) - print("Poses Shape: ", f.entity_poses.shape) + # Create a new robofish io file + f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0) + f.attrs["experiment_setup"] = "This is a simple example with made up data." + + # Create a new robot entity. Positions and orientations are passed + # separately in this example. Since the orientations have two columns, + # unit vectors are assumed (orientation_x, orientation_y) + f.create_entity( + category="robot", + name="robot", + positions=np.zeros((100, 2)), + orientations=np.ones((100, 2)) * [0, 1], + ) + + # Create a new fish entity. + # In this case, we pass positions and orientations together (x, y, rad). + # Since it is a 3 column array, orientations in radiants are assumed. + poses = np.zeros((100, 3)) + poses[:, 0] = np.arange(-50, 50) + poses[:, 1] = np.arange(-50, 50) + poses[:, 2] = np.arange(0, 2 * np.pi, step=2 * np.pi / 100) + fish = f.create_entity("fish", poses=poses) + fish.attrs["species"] = "My rotating spaghetti fish" + fish.attrs["fish_standard_length_cm"] = 10 + + # Show and save the file + print(f) + print("Poses Shape: ", f.entity_poses.shape) + + f.save_as(path) if __name__ == "__main__": diff --git a/setup.py b/setup.py index 5be2fe4573c17f9dbb776058058ea0e38c4674aa..552f54e96c461e588672a3238c82b98cee1b8030 100644 --- a/setup.py +++ b/setup.py @@ -14,9 +14,15 @@ entry_points = { ] } + def source_version(): version_parts = ( - run(["git", "describe", "--tags", "--dirty"], check=True, stdout=PIPE, encoding="utf-8") + run( + ["git", "describe", "--tags", "--dirty"], + check=True, + stdout=PIPE, + encoding="utf-8", + ) .stdout.strip() .split("-") ) @@ -36,12 +42,13 @@ def source_version(): return version + setup( name="robofish-io", version=source_version(), author="", author_email="", - install_requires=["h5py>=3", "numpy", "seaborn", "pandas", "deprecation"], + install_requires=["h5py>=2.10.0", "numpy", "seaborn", "pandas", "deprecation"], classifiers=[ "Development Status :: 3 - Alpha", "Intended Audience :: Science/Research", diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index 8ef557c3d7d0bdcdd010f47055a121ca3452114f..669ac1746ab15dd5c24e6eae1b6986bf97c6a49b 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -30,6 +30,12 @@ class Entity(h5py.Group): poses, positions, orientations, outlines ) + assert poses is None or (poses.ndim == 2 and poses.shape[1] in [3, 4]) + assert positions is None or (positions.ndim == 2 and positions.shape[1] == 2) + assert orientations is None or ( + orientations.ndim == 2 and orientations.shape[1] in [1, 2] + ) + # If no name is given, create one from type and an id if name is None: i = 1 @@ -53,6 +59,10 @@ class Entity(h5py.Group): @classmethod def convert_rad_to_vector(cla, orientations_rad): + if min(orientations_rad) < 0 or max(orientations_rad) > 2 * np.pi: + logging.warning( + "Converting orientations, from a bigger range than [0, 2 * pi]. When passing the orientations, they are assumed to be in radiants." + ) ori_rad = utils.np_array(orientations_rad) assert ori_rad.shape[1] == 1 ori_vec = np.empty((ori_rad.shape[0], 2)) @@ -98,7 +108,7 @@ class Entity(h5py.Group): else: if poses is not None: - assert poses.shape[1] == 3 or poses.shape[1] == 4 + assert poses.shape[1] in [3, 4] positions = poses[:, :2] orientations = poses[:, 2:] if orientations is not None and orientations.shape[1] == 1: @@ -107,9 +117,10 @@ class Entity(h5py.Group): positions = self.create_dataset( "positions", data=positions, dtype=np.float32 ) - orientations = self.create_dataset( - "orientations", data=orientations, dtype=np.float32 - ) + if orientations is not None: + orientations = self.create_dataset( + "orientations", data=orientations, dtype=np.float32 + ) if sampling is not None: positions.attrs["sampling"] = sampling @@ -121,6 +132,9 @@ class Entity(h5py.Group): @property def orientations(self): + if not "orientations" in self: + # If no orientation is given, the default direction is to the right + return np.tile([1, 0], (self.positions.shape[0], 1)) return self["orientations"] @property diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 669c05c0bc3e019b099e4357e69ae07534318862..e4798c809fc28fe367e7b8198065622dcdbc341a 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -208,14 +208,19 @@ class File(h5py.File): ) if default: - self.default_sampling = name - self["samplings"].attrs["default"] = self.default_sampling + self["samplings"].attrs["default"] = name return name @property def world_size(self): return self.attrs["world_size_cm"] + @property + def default_sampling(self): + if "default" in self["samplings"].attrs: + return self["samplings"].attrs["default"] + return None + @property def frequency(self): # NOTE: Only works if default sampling availabe and specified with frequency_hz. @@ -245,11 +250,10 @@ class File(h5py.File): Name of the created entity """ - if sampling is None: - if not hasattr(self, "default_sampling"): - raise Exception( - "There was no sampling specified, when creating the file, nor when creating the entity." - ) + if sampling is None and self.default_sampling is None: + raise Exception( + "There was no sampling specified, when creating the file, nor when creating the entity." + ) entity = robofish.io.Entity.create_entity( self["entities"], @@ -285,18 +289,18 @@ class File(h5py.File): """ assert poses.ndim == 3 - assert poses.shape[2] == 4 - n = poses.shape[0] + assert poses.shape[2] in [3, 4] + agents = poses.shape[0] timesteps = poses.shape[1] - returned_names = [] + entity_names = [] - for i in range(n): + for i in range(agents): e_name = None if names is None else names[i] e_outline = ( outlines if outlines is None or outlines.ndim == 3 else outlines[i] ) - returned_names.append( + entity_names.append( self.create_entity( category=category, sampling=sampling, @@ -305,7 +309,7 @@ class File(h5py.File): outlines=e_outline, ) ) - return returned_names + return entity_names @property def entity_names(self) -> Iterable[str]: @@ -327,8 +331,13 @@ class File(h5py.File): def entity_poses(self): return self.select_entity_poses(None) - def select_entity_poses(self, predicate=None) -> Iterable: - """ Select an array of the poses of entities + @property + def entity_poses_rad(self): + return self.select_entity_poses(None, rad=True) + + def select_entity_poses(self, predicate=None, rad=False) -> Iterable: + """ TODO: Rework + Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -346,7 +355,8 @@ class File(h5py.File): max_timesteps = max([0] + [e.positions.shape[0] for e in entities]) # Initialize poses output array - poses_output = np.empty((len(entities), max_timesteps, 4)) + pose_len = 3 if rad else 4 + poses_output = np.empty((len(entities), max_timesteps, pose_len)) poses_output[:] = np.nan # Fill poses output array @@ -361,7 +371,7 @@ class File(h5py.File): raise Exception( "Multiple samplings found, preventing return of a single array." ) - poses = entity.poses + poses = entity.poses_rad if rad else entity.poses poses_output[i][: poses.shape[0]] = poses i += 1 return poses_output diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 0718d2d95b5685baa46147f3f390179c0b8a16f5..a1bde527f16a2b28431c6b1148a4ecfec50d2354 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -200,7 +200,9 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): if positions.shape[0] > 0: # validate range of poses - validate_poses_range(iofile, positions, e_name) + validate_positions_range( + iofile.attrs["world_size_cm"], positions, e_name + ) if "orientations" in entity: assert_validate( @@ -228,6 +230,8 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): e_name, ) + validate_orientations_length(orientations, e_name) + # outlines if "outlines" in entity: outlines = entity["outlines"] @@ -291,37 +295,48 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): return (True, "") -def validate_poses_range(iofile, poses, e_name): - # Poses which are just a bit over the world edge are fine +def validate_positions_range(world_size, positions, e_name): + # positions which are just a bit over the world edge are fine error_allowance = 1.01 allowed_x = [ - -1 * iofile.attrs["world_size_cm"][0] * error_allowance / 2, - iofile.attrs["world_size_cm"][0] * error_allowance / 2, + -1 * world_size[0] * error_allowance / 2, + world_size[0] * error_allowance / 2, ] - real_x = [poses[:, 0].min(), poses[:, 0].max()] + real_x = [positions[:, 0].min(), positions[:, 0].max()] allowed_y = [ - -1 * iofile.attrs["world_size_cm"][1] * error_allowance / 2.0, - iofile.attrs["world_size_cm"][1] * error_allowance / 2.0, + -1 * world_size[1] * error_allowance / 2.0, + world_size[1] * error_allowance / 2.0, ] - real_y = [poses[:, 1].min(), poses[:, 1].max()] + real_y = [positions[:, 1].min(), positions[:, 1].max()] assert_validate( allowed_x[0] <= real_x[0] and real_x[1] <= allowed_x[1], - "Poses of x axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the poses" + "Positions of x axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions" % (allowed_x[0], allowed_x[1], real_x[0], real_x[1]), e_name, ) assert_validate( allowed_y[0] <= real_y[0] and real_y[1] <= allowed_y[1], - "Poses of y axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the poses" + "Positions of y axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions" % (allowed_y[0], allowed_y[1], real_y[0], real_y[1]), e_name, ) +def validate_orientations_length(orientations, e_name): + ori_lengths = np.linalg.norm(orientations, axis=1) + + assert_validate( + np.isclose(ori_lengths, 1).all(), + "The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1" + % (min(ori_lengths), max(ori_lengths)), + e_name, + ) + + def validate_iso8601(str_val: str) -> bool: """This function validates strings to match the ISO8601 format. diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 58ca278882b1756830fd66b7a37a4cb823d8199a..2a7fe486a95c961f41f591ea7c8ab533db6b2ccf 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,7 +22,7 @@ def test_app_validate(): self.names = None self.save_path = graphics_out - # TODO: Get rid of deprecated get_poses function + # TODO: Get rid of deprecation with pytest.warns(DeprecationWarning): app.evaluate(DummyArgs("speed")) graphics_out.unlink() diff --git a/tests/robofish/io/test_examples.py b/tests/robofish/io/test_examples.py index 1d0379fb1b06448e5f733c3690ff48a4f623592c..c81b1eed99221e9d92a82f0d354352fa869cfbae 100644 --- a/tests/robofish/io/test_examples.py +++ b/tests/robofish/io/test_examples.py @@ -1,12 +1,13 @@ import robofish.io from robofish.io import utils from pathlib import Path +from testbook import testbook import sys sys.path.append(str(utils.full_path(__file__, "../../../examples/"))) - +ipynb_path = utils.full_path(__file__, "../../../examples/example_basic.ipynb") path = utils.full_path(__file__, "../../../examples/tmp_example.hdf5") if path.exists(): path.unlink() @@ -24,3 +25,21 @@ def test_example_basic(): example_basic.create_example_file(path) path.unlink() + + +# This test can be executed manually. The CI/CD System has issues with testbook. +def manual_test_example_basic_ipynb(): + # Executing the notebook should not lead to an exception + with testbook(str(ipynb_path), execute=True) as tb: + pass + # tb.ref("create_example_file")(path) + # path.unlink() + + +if __name__ == "__main__": + print("example_readme.py") + test_example_readme() + print("example_basic.py") + test_example_basic() + print("example_basic.ipynb") + manual_test_example_basic_ipynb() diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index e19031ea8008470df9560a2e685b8716826c6e11..31c54bc8382f82aa6376af26f08fcafe257f619a 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -8,6 +8,8 @@ import datetime import sys import logging +LOGGER = logging.getLogger(__name__) + valid_file_path = utils.full_path(__file__, "../../resources/valid.hdf5") created_by_test_path = utils.full_path(__file__, "../../resources/created_by_test.hdf5") created_by_test_path_2 = utils.full_path( @@ -42,6 +44,7 @@ def test_missing_attribute(): def test_single_entity_monotonic_step(): sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=(1000 / 40)) test_poses = np.ones(shape=(10, 4)) + test_poses[:, 3] = 0 # All Fish pointing right sf.create_entity("robofish", poses=test_poses) sf.create_entity("object_without_poses") print(sf) @@ -53,6 +56,7 @@ def test_single_entity_monotonic_time_points_us(): world_size_cm=[100, 100], monotonic_time_points_us=np.ones(10) ) test_poses = np.ones(shape=(10, 4)) + test_poses[:, 3] = 0 # All Fish pointing right sf.create_entity("robofish", poses=test_poses) print(sf) sf.validate() @@ -64,7 +68,7 @@ def test_multiple_entities(): poses = np.zeros((agents, timesteps, 4)) poses[1] = 1 - poses[2] = 2 + poses[:, :, 2:] = [1, 0] # All agents point to the right m_points = np.arange(timesteps) @@ -95,9 +99,9 @@ def test_multiple_entities(): assert (returned_poses == poses[:1]).all() # Insert some random obstacles - returned_names = sf.create_multiple_entities( - "obstacle", poses=np.random.random((agents, timesteps, 4)) - ) + obstacles = 3 + obs_poses = np.random.random((agents, 1, 3)) + returned_names = sf.create_multiple_entities("obstacle", poses=obs_poses) # Obstacles should not be returned when only fish are selected returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() @@ -115,12 +119,6 @@ def test_multiple_entities(): monotonic_time_points_us=m_points, calendar_time_points=c_points ) - try: - sf.create_sampling(frequency_hz=25, monotonic_time_points_us=m_points) - raise Exception("This sampling should have created an error") - except: - pass - returned_names = sf.create_multiple_entities( "fish", poses, outlines=outlines, sampling=new_sampling ) @@ -129,6 +127,9 @@ def test_multiple_entities(): # pass an poses array in separate parts (positions, orientations) and retrieve it with poses. poses_arr = np.random.random((100, 4)) + poses_arr[:, 2:] /= np.atleast_2d( + np.linalg.norm(poses_arr[:, 2:], axis=1) + ).T # Normalization position_orientation_fish = sf.create_entity( "fish", positions=poses_arr[:, :2], orientations=poses_arr[:, 2:] ) @@ -138,6 +139,15 @@ def test_multiple_entities(): return sf +def test_broken_sampling(caplog): + sf = robofish.io.File(world_size_cm=[10, 10]) + caplog.set_level(logging.ERROR) + broken_sampling = sf.create_sampling( + name="broken sampling", frequency_hz=25, monotonic_time_points_us=np.ones((100)) + ) + assert "ERROR" in caplog.text + + def test_deprecated_get_poses(): f = test_multiple_entities() with pytest.warns(DeprecationWarning): @@ -146,6 +156,29 @@ def test_deprecated_get_poses(): assert f.get_poses(names="fish_1").shape[0] == 1 +def test_entity_poses_rad(caplog): + with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f: + # Create an entity, using radiants + f.create_entity("fish", poses=np.ones((100, 3))) + + # Read the poses of the file as radiants + np.testing.assert_almost_equal(f.entity_poses_rad, np.ones((1, 100, 3))) + + caplog.set_level(logging.WARNING) + # Passing orientations, which are bigger than 2 pi + f.create_entity("fish", poses=np.ones((100, 3)) * 50) + assert "WARNING" in caplog.text + + +def test_entity_positions_no_orientation(): + with robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) as f: + # Create an entity, using radiants + f.create_entity("fish", positions=np.ones((100, 2))) + + assert f.entity_poses.shape == (1, 100, 4) + assert (f.entity_poses[:, :] == np.array([1, 1, 1, 0])).all() + + def test_load_validate(): sf = robofish.io.File(path=valid_file_path) sf.validate() @@ -173,6 +206,12 @@ def test_loading_saving(): # After saving, the file should still be accessible and valid sf.validate() + # Open the file again and add another entity + sf = robofish.io.File(created_by_test_path, "r+") + entity = sf.create_entity("fish", positions=np.ones((100, 2))) + sf.entity_poses + entity.poses + def test_validate_created_file_after_reloading(): sf = robofish.io.File(created_by_test_path) @@ -182,8 +221,9 @@ def test_validate_created_file_after_reloading(): # Cleanup test. The z in the name makes sure, that it is executed last in main def test_z_cleanup(): """ This cleans up after all tests and removes all test artifacts """ - created_by_test_path.unlink() - created_by_test_path_2.unlink() + for f in [created_by_test_path, created_by_test_path_2]: + if f.exists(): + f.unlink() if __name__ == "__main__":