diff --git a/.gitignore b/.gitignore index 764c5ec57cd932bc024ae62cff117e517559d4a3..0cd5109888fa7ddcebdb5a6386249d8b1aed7110 100644 --- a/.gitignore +++ b/.gitignore @@ -21,3 +21,4 @@ env feature_requests.md output_graph.png +.testmondata diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index d46d2a643b5b1faebf3f373bf5f2ba297e691c32..fe48edbeeea0c0d22ba06cd3229f4487225e4562 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -705,6 +705,8 @@ def evaluate_all( fdict: A dictionary of strings and evaluation functions predicate: a lambda function, selecting entities (example: lambda e: e.category == "fish") + Returns: + An array of all paths to the created images. """ assert ( save_folder is not None and save_folder != "" @@ -712,6 +714,7 @@ def evaluate_all( save_folder.mkdir(exist_ok=True) + save_paths = [] if fdict is None: fdict = robofish.evaluate.app.function_dict() fdict.pop("all") @@ -720,7 +723,11 @@ def evaluate_all( for f_name, f_callable in t: t.set_description(f_name) t.refresh() # to show immediately the update - f_callable(paths, labels, save_folder / (f_name + ".png"), predicate) + save_path = save_folder / (f_name + ".png") + f_callable(paths, labels, save_path, predicate) + save_paths.append(save_path) + + return save_paths def calculate_follow(a, b): @@ -745,6 +752,20 @@ def calculate_iid(a, b): return np.linalg.norm(b[:, :2] - a[:, :2], axis=-1) +def calc_tlvc(a, b, tau_min, tau_max): + """ + Given two velocity series and both minimum and maximum time lag return the + time lagged velocity correlation from the first to the second series. + """ + length = tau_max - tau_min + return np.float32( + [ + (a[t] @ b[t + tau_min :][:length].T).mean() + for t in range(min(len(a), len(b) - tau_max + 1)) + ] + ) + + def normalize_series(x): """Normalize series. diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 14272418e2ad968f57b562255fe4cd429bce4388..82cfefcad880788bf6d79895acf02ea7c344aaf2 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -37,6 +37,8 @@ import matplotlib.pyplot as plt from matplotlib import animation from matplotlib import patches +from subprocess import run + # Remember: Update docstring when updating these two global variables default_format_version = np.array([1, 0], dtype=np.int32) @@ -66,6 +68,7 @@ class File(h5py.File): frequency_hz: int = None, monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, + open_copy: bool = False, ): """Create a new RoboFish Track Format object. @@ -126,15 +129,35 @@ class File(h5py.File): refer to explanation of `sampling_name` calendar_time_points: Iterable of str, optional refer to explanation of `sampling_name` + open_copy: bool, optional + a temporary copy of the file will be opened instead of the file itself. """ - if path is None: - if type(self)._temp_dir is None: - type(self)._temp_dir = tempfile.TemporaryDirectory( - prefix="robofish-io-" - ) + if open_copy: + assert ( + path is not None + ), "A path has to be given if a copy should be opened." + + temp_file = self.temp_dir / str(uuid.uuid4()) + logging.info( + f"Copying file to temporary file and opening it:\n{path} -> {temp_file}" + ) + + shutil.copyfile(path, temp_file) + super().__init__( + temp_file, + mode="r+", + driver="core", + backing_store=True, + libver=("earliest", "v110"), + ) + initialize = False + + elif path is None: + temp_file = self.temp_dir / str(uuid.uuid4()) + logging.info(f"Opening New temporary file {temp_file}") super().__init__( - Path(type(self)._temp_dir.name) / str(uuid.uuid4()), + temp_file, mode="x", driver="core", backing_store=True, @@ -182,7 +205,12 @@ class File(h5py.File): self.validate() super().__exit__(type, value, traceback) - def save_as(self, path: Union[str, Path], strict_validate: bool = True, no_warning: bool=False): + def save_as( + self, + path: Union[str, Path], + strict_validate: bool = True, + no_warning: bool = False, + ): """Save a copy of the file Args: @@ -269,6 +297,13 @@ class File(h5py.File): self["samplings"].attrs["default"] = name return name + @property + def temp_dir(self): + cla = type(self) + if cla._temp_dir is None: + cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-") + return Path(cla._temp_dir.name) + @property def world_size(self): return self.attrs["world_size_cm"] @@ -462,8 +497,8 @@ class File(h5py.File): @deprecation.deprecated( deprecated_in="0.2", removed_in="0.2.4", - details="We found that our calculation of 'speed_turn' is flawed and replaced it " - "with 'actions_speeds_turns'. The difference in calculation is, that the tracked " + details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it " + "with 'entity_actions_speeds_turns'. The difference in calculation is, that the tracked " "orientation is used now which gives the fish the ability to swim backwards. " "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n" "Don't ignore this warning, it's a serious issue.", @@ -613,6 +648,14 @@ class File(h5py.File): As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition. The goal should be to bring together the rendering tools.""" + if video_path is not None: + try: + run(["ffmpeg"], capture_output=True) + except Exception as e: + raise Exception( + f"ffmpeg is required to store videos. Please install it.\n{e}" + ) + def shape_vertices(scale=1) -> np.ndarray: base_shape = np.array( [ diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 2ec1f26b2d0804f20c70d685c7576a3fe4a6bc05..601657d173f643d62502e5732c4ae5acb0a28166 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -16,6 +16,7 @@ valid_file_path = utils.full_path(__file__, "../../resources/valid_1.hdf5") def test_constructor(): sf = robofish.io.File(world_size_cm=[100, 100]) sf.validate() + sf.close() def test_context(): @@ -28,12 +29,14 @@ def test_new_file_w_path(tmp_path): sf = robofish.io.File(f, "w", world_size_cm=[100, 100], frequency_hz=25) sf.create_entity("fish") sf.validate() + sf.close() def test_missing_attribute(): sf = robofish.io.File(world_size_cm=[10, 10]) sf.attrs.pop("world_size_cm") assert not sf.validate(strict_validate=False)[0] + sf.close() def test_single_entity_monotonic_step(): @@ -44,6 +47,7 @@ def test_single_entity_monotonic_step(): sf.create_entity("object_without_poses") print(sf) sf.validate() + sf.close() def test_single_entity_monotonic_time_points_us(): @@ -55,6 +59,7 @@ def test_single_entity_monotonic_time_points_us(): sf.create_entity("robofish", poses=test_poses) print(sf) sf.validate() + sf.close() def test_multiple_entities(): @@ -174,6 +179,7 @@ def test_entity_positions_no_orientation(): def test_load_validate(): sf = robofish.io.File(path=valid_file_path) sf.validate() + sf.close() def test_get_entity_names(): @@ -181,6 +187,19 @@ def test_get_entity_names(): names = sf.entity_names assert len(names) == 2 assert names == ["fish_1", "robot"] + sf.close() + + +def test_load_copy(): + # Open a copy of a file and change an attribute + sf = robofish.io.File(path=valid_file_path, open_copy=True) + sf.attrs["test"] = "test" + sf.close() + + # When reopening the file, the attribute should be not saved. + sf = robofish.io.File(valid_file_path, "r") + assert "test" not in sf.attrs + sf.close() def test_File_without_path_or_worldsize(): @@ -210,6 +229,7 @@ def test_loading_saving(tmp_path): entity.poses sf.validate() + sf.close() if __name__ == "__main__":