diff --git a/.gitignore b/.gitignore index 9d824fe64a4c864975d4b78b493ca11df2fcb1a9..c93c2451ee6958af633a4e24fd66c34304682f73 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ **/__pycache__ +build dist *.egg *.egg-info* diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index e9300a2487b01778102c5d10a11ba8053384113c..d579b9a074657960670aa4eedd0dcf8c2bc73fed 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,23 +1,23 @@ stages: - - test - package + - test - deploy .centos: tags: [linux, docker] image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/centos:latest +.macos: + tags: [macos, shell] + .windows: tags: [windows, docker] image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/windows:latest-devel before_script: - . $Profile.AllUsersAllHosts -.cpy37: &cpy37 - PYTHON_EXECUTABLE: "python3.7" - -.cpy38: &cpy38 - PYTHON_EXECUTABLE: "python3.8" +.python38: &python38 + PYTHON_VERSION: "3.8" .test: &test stage: test @@ -28,79 +28,53 @@ stages: script: - ./ci/test.py -test centos[cpy37]: - extends: .centos - <<: *test - variables: - <<: [*cpy37] - -test centos[cpy38]: - extends: .centos - <<: *test - variables: - <<: [*cpy38] - -.test windows[cpy37]: - extends: .windows - <<: *test - .package: &package stage: package artifacts: paths: - dist expire_in: 1 week + reports: + dotenv: build.env script: - ./ci/package.py -package centos[cpy37]: +package: extends: .centos - dependencies: - - test centos[cpy37] <<: *package variables: - <<: [*cpy37] + <<: [*python38] -package centos[cpy38]: +"test: [centos, 3.8]": extends: .centos - dependencies: - - test centos[cpy38] - <<: *package - variables: - <<: [*cpy38] + <<: *test -.package windows[cpy37]: - extends: .windows - dependencies: - - test windows[cpy37] - <<: *package +"test: [macos, 3.8]": + extends: .macos + <<: *test -deploy centos[cpy37]: - extends: .centos - stage: deploy - only: - - tags - dependencies: - - package centos[cpy37] - script: - - ./ci/deploy.py +"test: [windows, 3.8]": + extends: .windows + <<: *test -.deploy centos[cpy38]: +deploy to staging: extends: .centos stage: deploy only: + - master - tags + allow_failure: true dependencies: - - package centos[cpy38] + - package script: - ./ci/deploy.py -.deploy windows[cpy37]: +deploy to production: extends: .centos stage: deploy only: - tags dependencies: - - package windows[cpy37] + - package script: - - ./ci/deploy.py + - ./ci/deploy.py --production diff --git a/ci/deploy.py b/ci/deploy.py index 6dda82b345943f9e96b35acb1d1858343e51a1c8..0437457696a996d508cb3b4235ba6feaeb49d52c 100755 --- a/ci/deploy.py +++ b/ci/deploy.py @@ -3,22 +3,31 @@ from os import environ as env from subprocess import check_call -from pathlib import Path from platform import system +from argparse import ArgumentParser if __name__ == "__main__": if system() != "Linux": raise Exception("Uploading python package only supported on Linux") + p = ArgumentParser() + p.add_argument("--production", default=False, action="store_const", const=True) + args = p.parse_args() + env["TWINE_USERNAME"] = "gitlab-ci-token" env["TWINE_PASSWORD"] = env["CI_JOB_TOKEN"] + if args.production: + target_project_id = env['ARTIFACTS_REPOSITORY_PROJECT_ID'] + else: + target_project_id = env['CI_PROJECT_ID'] + command = ["python3"] command += ["-m", "twine", "upload", "dist/*"] command += [ "--repository-url", - f"https://git.imp.fu-berlin.de/api/v4/projects/{env['ARTIFACTS_REPOSITORY_PROJECT_ID']}/packages/pypi", + f"https://git.imp.fu-berlin.de/api/v4/projects/{target_project_id}/packages/pypi", ] check_call(command) diff --git a/ci/package.py b/ci/package.py index 4be105e3f11445e3c9c8e9d183882fcac116d8ba..38ef2d7f540eb37068ae0a0df5123173962e2789 100755 --- a/ci/package.py +++ b/ci/package.py @@ -2,16 +2,17 @@ # SPDX-License-Identifier: LGPL-3.0-or-later from os import environ as env -from platform import system from subprocess import check_call -from sys import executable +from pathlib import Path +from platform import system +from shutil import which def python_executable(): if system() == "Windows": - return executable - elif system() == "Linux": - return env["PYTHON_EXECUTABLE"] + return f"/Python{''.join(env['PYTHON_VERSION'].split('.'))}/python.exe" + elif system() == "Linux" or system() == "Darwin": + return which(f"python{env['PYTHON_VERSION']}") assert False @@ -20,3 +21,6 @@ if __name__ == "__main__": command += ["setup.py", "bdist_wheel"] check_call(command) + + with open("build.env", "w") as f: + f.write(f"PYTHON_VERSION={env['PYTHON_VERSION']}\n") diff --git a/ci/test.py b/ci/test.py index 992a7cf573a716ff839a26989fd45dab6ce66c6c..c6d90a39d599fc668a6a1ed525c2a940ddeeec4f 100755 --- a/ci/test.py +++ b/ci/test.py @@ -1,28 +1,52 @@ #! /usr/bin/env python3 # SPDX-License-Identifier: LGPL-3.0-or-later -from os import environ as env +from os import environ as env, pathsep from subprocess import check_call +from pathlib import Path from platform import system -from sys import executable +from shutil import which def python_executable(): if system() == "Windows": - return executable - elif system() == "Linux": - return env["PYTHON_EXECUTABLE"] + return f"/Python{''.join(env['PYTHON_VERSION'].split('.'))}/python.exe" + elif system() == "Linux" or system() == "Darwin": + return which(f"python{env['PYTHON_VERSION']}") + assert False + + +def python_venv_executable(): + if system() == "Windows": + return str(Path(".venv/Scripts/python.exe").resolve()) + elif system() == "Linux" or system() == "Darwin": + return str(Path(".venv/bin/python").resolve()) assert False if __name__ == "__main__": - check_call([python_executable(), "-m", "pip", "install", "pytest"]) - # check_call([python_executable(), "-m", "pip", "install", "pytest-cov"]) - check_call([python_executable(), "-m", "pip", "install", "h5py"]) - check_call([python_executable(), "-m", "pip", "install", "pandas"]) - check_call([python_executable(), "-m", "pip", "install", "deprecation"]) + check_call( + [ + python_executable(), + "-m", + "venv", + "--system-site-packages", + "--prompt", + "ci", + ".venv", + ], + ) - command = [python_executable()] - command += ["-m", "pytest", "--junitxml=report.xml"] + check_call( + [ + python_venv_executable(), + "-m", + "pip", + "install", + str(sorted(Path("dist").glob("*.whl"))[-1].resolve()), + ], + ) - check_call(command) + check_call( + [python_venv_executable(), "-m", "pytest", "--junitxml=report.xml"], + ) diff --git a/examples/example_basic.ipynb b/examples/example_basic.ipynb index 253ff2fceacac9ad10c6399edf19a2211ed55b9e..ada157912f8586719aa91e45432645fd65248322 100644 --- a/examples/example_basic.ipynb +++ b/examples/example_basic.ipynb @@ -148,7 +148,7 @@ " sf = robofish.io.File(path=example_file)\n", "\n", " print(\"\\nEntity Names\")\n", - " print(sf.get_entity_names())\n", + " print(sf.entity_names)\n", "\n", " # Get an array with all poses. As the length of poses varies per agent, it\n", " # is filled up with nans. The result is not interpolated and the time scales\n", @@ -156,10 +156,10 @@ " # different time scales and have another function, which generates an\n", " # interpolated array.\n", " print(\"\\nAll poses\")\n", - " print(sf.get_poses_array())\n", + " print(sf.select_entity_poses())\n", "\n", " print(\"\\nFish poses\")\n", - " print(sf.get_poses_array(fish_names))\n", + " print(sf.select_entity_poses(lambda e: e.category == \"fish\"))\n", "\n", " print(\"\\nFile structure\")\n", " print(sf)\n" @@ -187,4 +187,4 @@ }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file diff --git a/examples/example_basic.py b/examples/example_basic.py index e328d380144d327512f7476e3493e687b57b43d8..b8d149ce4383025767d1d262335c4f20da3c2cc1 100755 --- a/examples/example_basic.py +++ b/examples/example_basic.py @@ -46,14 +46,14 @@ if __name__ == "__main__": sf = robofish.io.File(path=example_file) print("\nEntity Names") - print(sf.get_entity_names()) + print(sf.entity_names) # Get an array with all poses. As the length of poses varies per agent, it is filled up with nans. print("\nAll poses") - print(sf.get_poses()) + print(sf.entity_poses) print("\nFish poses") - print(sf.get_poses(category="fish")) + print(sf.select_entity_poses(lambda e: e.category == "fish")) print("\nFile structure") print(sf) diff --git a/setup.py b/setup.py index ce7fa01d0447d8634eda4aa8e6a8ad47b685f6c5..5be2fe4573c17f9dbb776058058ea0e38c4674aa 100644 --- a/setup.py +++ b/setup.py @@ -1,5 +1,7 @@ # SPDX-License-Identifier: LGPL-3.0-or-later +from subprocess import run, PIPE + from setuptools import setup, find_packages @@ -11,9 +13,32 @@ entry_points = { "robofish-io-evaluate=robofish.evaluate.app:evaluate", ] } + +def source_version(): + version_parts = ( + run(["git", "describe", "--tags", "--dirty"], check=True, stdout=PIPE, encoding="utf-8") + .stdout.strip() + .split("-") + ) + + if version_parts[-1] == "dirty": + dirty = True + version_parts = version_parts[:-1] + else: + dirty = False + + version = version_parts[0] + if len(version_parts) == 3: + version += ".post0" + version += f".dev{version_parts[1]}+{version_parts[2]}" + if dirty: + version += "+dirty" + + return version + setup( name="robofish-io", - version="0.1", + version=source_version(), author="", author_email="", install_requires=["h5py>=3", "numpy", "seaborn", "pandas", "deprecation"], @@ -27,7 +52,7 @@ setup( "Programming Language :: Python :: 3.8", ], python_requires=">=3.6", - packages=find_packages("src"), + packages=[f"robofish.{p}" for p in find_packages("src/robofish")], package_dir={"": "src"}, zip_safe=True, entry_points=entry_points, diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index e5329988e71b9f8e54018aee9e339b614ad73f64..89e5c22de9fac242391161760af0fd2d0779f8d7 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -54,8 +54,17 @@ class Entity(h5py.Group): ori_vec[:, 1] = np.sin(ori_rad[:, 0]) return ori_vec - def getName(self): - return self.name.split("/")[-1] + @property + def group_name(self): + return super().name + + @property + def name(self): + return self.group_name.split("/")[-1] + + @property + def category(self): + return self.attrs["category"] def create_outlines(self, outlines: Iterable, sampling=None): outlines = self.create_dataset("outlines", data=outlines, dtype=np.float32) @@ -98,12 +107,21 @@ class Entity(h5py.Group): positions.attrs["sampling"] = sampling orientations.attrs["sampling"] = sampling - def get_poses(self): - poses = np.concatenate([self["positions"], self["orientations"]], axis=1) - return poses + @property + def positions(self): + return self["positions"] + + @property + def orientations(self): + return self["orientations"] + + @property + def poses(self): + return np.concatenate([self.positions, self.orientations], axis=1) - def get_poses_rad(self): - poses = self.get_poses() + @property + def poses_rad(self): + poses = self.poses # calculate the angles from the orientation vectors, write them to the third row and delete the fourth row poses[:, 2] = np.arctan2(poses[:, 3], poses[:, 2]) poses = poses[:, :3] diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index deefcc9a771d6e54966224be4be1dc1abbb69c16..3bc5833140f9edbe3f87bf23227ac55a0e654610 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -28,26 +28,25 @@ import tempfile import uuid import deprecation - -temp_dir = tempfile.TemporaryDirectory() default_format_version = np.array([1, 0], dtype=np.int32) default_format_url = ( "https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0" ) - class File(h5py.File): - """ Represents a hdf5 file, which should be used to store data about the - movement and shape of individuals or swarms in time. + """ Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms. - This class extends the h5py.File class, and behaves in the same way. - robofish io Files can be created, saved, loaded, and manipulated. + Files can be opened (with optional creation), modified inplace, and have copies of them saved. """ + _temp_dir = None + def __init__( self, path: Union[str, Path] = None, + mode: str = "r", + *, # PEP 3102 world_size_cm: [int, int] = None, strict_validate: bool = False, format_version: [int, int] = default_format_version, @@ -57,25 +56,47 @@ class File(h5py.File): monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, ): - """ Constructor for the File class. - - The constructor should either be called with a path, when loading an existing file, - or when a new file should be created with a world size. - - Args: - path: optional path to a io file as a string or path object. The file will be loaded. - world_size_cm: optional integer array of the world size in cm - strict_validate: optional boolean, if the file should be strictly validated, when loaded from a path. The default is False. - format_version: optional version [major, minor] of the trackformat specification + """Create a new RoboFish Track Format object. + + When called with a path, it is loaded, otherwise a new temporary file is created. + + Parameters + ---------- + path : str or Path, optional + Location of file to be opened. If not provided, mode is ignored. + + mode : str + r Readonly, file must exist (default) + r+ Read/write, file must exist + w Create file, truncate if exists + x Create file, fail if exists + a Read/write if exists, create otherwise + + world_size_cm + optional integer array of the world size in cm + strict_validate + optional boolean, if the file should be strictly validated, when loaded from a path. The default is False. + format_version + optional version [major, minor] of the trackformat specification """ - self._name = str(uuid.uuid4()) - self._tf_path = Path(temp_dir.name) / self._name - self.load(path, strict_validate) - - if path is None or not Path(path).exists(): + if path is None: + if type(self)._temp_dir is None: + type(self)._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-") + super().__init__(Path(type(self)._temp_dir.name) / str(uuid.uuid4()), mode="x", driver="core", backing_store=True, libver=("earliest", "v112")) + initialize = True + else: + #mode + #r Readonly, file must exist (default) + #r+ Read/write, file must exist + #w Create file, truncate if exists + #x Create file, fail if exists + #a Read/write if exists, create otherwise + logging.info(f"Opening File {path}") + initialize = not Path(path).exists() + super().__init__(path, mode, libver=("earliest", "v112")) - # Initialize new file + if initialize: assert world_size_cm is not None and format_version is not None self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32) @@ -93,58 +114,32 @@ class File(h5py.File): calendar_time_points=calendar_time_points, default=True, ) + self.validate(strict_validate) - #### File Handling #### - def load(self, path: Union[str, Path], strict_validate: bool = False) -> None: - """ Load a new file from a path + def __enter__(self): + return self - Args: - path: path to a io file as a string or path object - strict_validate: optional boolean, if the file should be strictly validated, when loaded from a path. The default is False. - """ - if path is not None: - self._f_path = Path(path) - - if path is not None and Path(path).exists(): - logging.info(f"Opening File {path}") - shutil.copyfile(self._f_path, self._tf_path) - super().__init__(self._tf_path, "r+") - self.validate(strict_validate) - else: - super().__init__(self._tf_path, "w") + def __exit__(self, type, value, traceback): + self.validate() + self.close() - def save(self, path: Union[str, Path] = None, strict_validate: bool = True): - """ Load a new file from a path + def save_as(self, path: Union[str, Path], strict_validate: bool = True): + """ Save a copy of the file Args: - path: optional path to a io file as a string or path object. If no path is specified, the last known path (from loading or saving) is used. + path: path to a io file as a string or path object. If no path is specified, the last known path (from loading or saving) is used. strict_validate: optional boolean, if the file should be strictly validated, before saving. The default is True. """ - # Normaly only valid files can be saved self.validate(strict_validate=strict_validate) - # Find the correct path - if path is None: - if self._f_path is None: - raise Exception( - "path was not specified and there was no saved path from loading or an earlier save" - ) - else: - self._f_path = Path(path) - - # Close the temporary file - self.close() - - # Create the parent folder if it does not exist - if not self._f_path.parent.exists(): - self._f_path.parent.mkdir(parents=True, exist_ok=True) + # Ensure all buffered data has been written to disk + self.flush() - # Copy the temporaryself.create_sampling(frequency_hz, monotonic_time_points_us) file to the path - shutil.copyfile(self._tf_path, self._f_path) + path = Path(path).resolve() + path.parent.mkdir(parents=True, exist_ok=True) - # Reopen the temporary file - super().__init__(self._tf_path, "r+") + shutil.copyfile(Path(self.filename).resolve(), path) def create_sampling( self, @@ -206,7 +201,13 @@ class File(h5py.File): self["samplings"].attrs["default"] = self.default_sampling return name - def get_frequency(self): + @property + def world_size(self): + return self.attrs["world_size_cm"] + + @property + def frequency(self): + # NOTE: Only works if default sampling availabe and specified with frequency_hz. default_sampling = self["samplings"].attrs["default"] return self["samplings"][default_sampling].attrs["frequency_hz"] @@ -295,13 +296,8 @@ class File(h5py.File): ) return returned_names - def get_entities(self): - return { - e_name: robofish.io.Entity.from_h5py_group(e_group) - for e_name, e_group in self["entities"].items() - } - - def get_entity_names(self) -> Iterable[str]: + @property + def entity_names(self) -> Iterable[str]: """ Getter for the names of all entities Returns: @@ -309,8 +305,16 @@ class File(h5py.File): """ return sorted(self["entities"].keys()) - def get_poses(self, names: Iterable = None, category: str = None) -> Iterable: - """ Get an array of the poses of entities + @property + def entities(self): + return [robofish.io.Entity.from_h5py_group(self["entities"][name]) for name in self.entity_names] + + @property + def entity_poses(self): + return self.select_entity_poses(None) + + def select_entity_poses(self, predicate = None) -> Iterable: + """ Select an array of the poses of entities If no name or category is specified, all entities will be selected. @@ -321,57 +325,29 @@ class File(h5py.File): An three dimensional array of all poses with the shape (entity, time, 4) """ - if names is not None and category is not None: - logging.error("Specify either names or a category, not both.") - raise Exception + entities = self.entities + if predicate is not None: + entities = [e for e in entities if predicate(e)] - # collect the names of all entities with the correct category - if category is not None: - names = [ - e_name - for e_name, e_data in self["entities"].items() - if e_data.attrs["category"] == category - ] - - entities = self.get_entities() - - # If no names or category are given, select all - if names is None: - names = sorted(entities.keys()) - - # Entity objects given as names - if all([type(name) == robofish.io.Entity for name in names]): - names = [entity.getName() for entity in names] - - if not all([type(name) == str for name in names]): - raise Exception( - "Given names were not strings. Instead names were %s" % names - ) - - max_timesteps = ( - 0 - if len(names) == 0 - else max([entities[e_name]["positions"].shape[0] for e_name in names]) - ) + max_timesteps = max([0] + [e.positions.shape[0] for e in entities]) # Initialize poses output array - poses_output = np.empty((len(names), max_timesteps, 4)) + poses_output = np.empty((len(entities), max_timesteps, 4)) poses_output[:] = np.nan # Fill poses output array i = 0 custom_sampling = None - for name in names: - entity = entities[name] + for entity in entities: if "sampling" in entity["positions"].attrs: if custom_sampling is None: custom_sampling = entity["positions"].attrs["sampling"] elif custom_sampling != entity["positions"].attrs["sampling"]: raise Exception( - "Multiple samplings found, which can not be given back by the get_poses function collectively." + "Multiple samplings found, preventing return of a single array." ) - poses = entity.get_poses() + poses = entity.poses poses_output[i][: poses.shape[0]] = poses i += 1 return poses_output diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index 776984b606fb89469a0d02cc860ac4489439ea63..0718d2d95b5685baa46147f3f390179c0b8a16f5 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -150,7 +150,9 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): # validate entities assert_validate("entities" in iofile, "entities not found") - for e_name, entity in iofile.get_entities().items(): + for entity in iofile.entities: + e_name = entity.name + assert_validate( type(entity) == Entity, "Entity group was not a robofish.io.Entity object", diff --git a/tests/robofish/io/test_entity.py b/tests/robofish/io/test_entity.py index 0ebf09738a53d8f983b8a99f8c54de7c351b8bda..8b3f22d809398a433f88dfb68c881e9d5dab2f82 100644 --- a/tests/robofish/io/test_entity.py +++ b/tests/robofish/io/test_entity.py @@ -12,7 +12,7 @@ def test_entity_object(): sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25) f = sf.create_entity("fish", positions=[[10, 10]]) assert type(f) == robofish.io.Entity - assert f.getName() == "fish_1" + assert f.name == "fish_1" assert f.attrs["category"] == "fish" print(dir(f)) print(f["positions"]) @@ -26,7 +26,7 @@ def test_entity_object(): f2 = sf.create_entity("fish", poses=poses_rad) assert type(f2["positions"]) == h5py.Dataset assert type(f2["orientations"]) == h5py.Dataset - poses_rad_retrieved = f2.get_poses_rad() + poses_rad_retrieved = f2.poses_rad # Check if retrieved rad poses is close to the original poses. # Internally always ori_x and ori_y are used. When retrieved, the range is from -pi to pi, so for some of our original data 2 pi has to be substracted. diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 5af5742cc246f4b33556aebe6e00657024411e92..36d2dca9b7c0fb7bb7111211c882a1ce9e1061c0 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -26,11 +26,9 @@ def test_constructor(): def test_new_file_w_path(): sf = robofish.io.File( - created_by_test_path_2, world_size_cm=[100, 100], frequency_hz=25 + created_by_test_path_2, "w", world_size_cm=[100, 100], frequency_hz=25 ) - poses = (np.random.random((100, 3)) - 0.5) * 100 - sf.create_entity("fish", poses=poses) - sf.save() + sf.create_entity("fish") sf.validate() @@ -71,7 +69,7 @@ def test_multiple_entities(): sf = robofish.io.File(world_size_cm=[100, 100], monotonic_time_points_us=m_points) returned_entities = sf.create_multiple_entities("fish", poses) - returned_names = [entity.getName() for entity in returned_entities] + returned_names = [entity.name for entity in returned_entities] expected_names = ["fish_1", "fish_2", "fish_3"] print(returned_names) @@ -81,24 +79,26 @@ def test_multiple_entities(): sf.validate() # The returned poses should be equal to the inserted poses - returned_poses = sf.get_poses() + returned_poses = sf.entity_poses print(returned_poses) assert (returned_poses == poses).all() # Just get the array for some names - returned_poses = sf.get_poses(["fish_1", "fish_2"]) + returned_poses = sf.select_entity_poses(lambda e: e.name in ["fish_1", "fish_2"]) assert (returned_poses == poses[:2]).all() - # Falsely specify names and category - with pytest.raises(Exception): - sf.get_poses(names=["fish_1"], category="fish") + # Filter on both category and name + returned_poses = sf.select_entity_poses( + lambda e: e.category == "fish" and e.name == "fish_1" + ) + assert (returned_poses == poses[:1]).all() # Insert some random obstacles returned_names = sf.create_multiple_entities( "obstacle", poses=np.random.random((agents, timesteps, 4)) ) # Obstacles should not be returned when only fish are selected - returned_poses = sf.get_poses(category="fish") + returned_poses = sf.select_entity_poses(lambda e: e.category == "fish") assert (returned_poses == poses).all() # for each of the entities @@ -126,12 +126,12 @@ def test_multiple_entities(): print(returned_names) print(sf) - # pass an poses array in separate parts (positions, orientations) and retreive it with get_poses. + # pass an poses array in separate parts (positions, orientations) and retrieve it with poses. poses_arr = np.random.random((100, 4)) position_orientation_fish = sf.create_entity( "fish", positions=poses_arr[:, :2], orientations=poses_arr[:, 2:] ) - assert np.isclose(poses_arr, position_orientation_fish.get_poses()).all() + assert np.isclose(poses_arr, position_orientation_fish.poses).all() sf.validate() return sf @@ -144,7 +144,7 @@ def test_load_validate(): def test_get_entity_names(): sf = robofish.io.File(path=valid_file_path) - names = sf.get_entity_names() + names = sf.entity_names assert len(names) == 9 assert names[0] == "fish_1" assert names[1] == "fish_2" @@ -159,7 +159,7 @@ def test_loading_saving(): sf = test_multiple_entities() assert not created_by_test_path.exists() - sf.save(created_by_test_path) + sf.save_as(created_by_test_path) assert created_by_test_path.exists() # After saving, the file should still be accessible and valid