Skip to content
Snippets Groups Projects
Commit 4a73d10d authored by Andi Gerken's avatar Andi Gerken
Browse files

Updated tqdm for evaluation

parent 8c5f1d83
Branches
No related tags found
No related merge requests found
Pipeline #48800 passed
...@@ -838,7 +838,6 @@ def evaluate_all( ...@@ -838,7 +838,6 @@ def evaluate_all(
fdict = robofish.evaluate.app.function_dict() fdict = robofish.evaluate.app.function_dict()
fdict.pop("all") fdict.pop("all")
print("Loading all poses and actions.")
poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths) poses_from_paths, file_settings = utils.get_all_poses_from_paths(paths)
speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths(
paths, "speeds_turns" paths, "speeds_turns"
......
...@@ -2,7 +2,7 @@ import robofish.io ...@@ -2,7 +2,7 @@ import robofish.io
import numpy as np import numpy as np
from typing import Union, Iterable from typing import Union, Iterable
from pathlib import Path from pathlib import Path
import os from tqdm import tqdm
def np_array(*arrays): def np_array(*arrays):
...@@ -78,6 +78,11 @@ def get_all_data_from_paths( ...@@ -78,6 +78,11 @@ def get_all_data_from_paths(
files_per_path = get_all_files_from_paths(paths) files_per_path = get_all_files_from_paths(paths)
pbar = tqdm(
total=sum([len(files_in_path) for files_in_path in files_per_path]),
desc=f"Loading {request_type} from files.",
)
# for each given path # for each given path
for files_in_path in files_per_path: for files_in_path in files_per_path:
data_from_files = [] data_from_files = []
...@@ -85,6 +90,8 @@ def get_all_data_from_paths( ...@@ -85,6 +90,8 @@ def get_all_data_from_paths(
# Open all files, gather the poses and check if all of them have the same world size and frequency # Open all files, gather the poses and check if all of them have the same world size and frequency
for i_path, file_path in enumerate(files_in_path): for i_path, file_path in enumerate(files_in_path):
with robofish.io.File(file_path, "r") as file: with robofish.io.File(file_path, "r") as file:
pbar.update(1)
pbar.refresh()
file_settings = { file_settings = {
"world_size_cm_x": file.attrs["world_size_cm"][0], "world_size_cm_x": file.attrs["world_size_cm"][0],
"world_size_cm_y": file.attrs["world_size_cm"][1], "world_size_cm_y": file.attrs["world_size_cm"][1],
...@@ -111,5 +118,5 @@ def get_all_data_from_paths( ...@@ -111,5 +118,5 @@ def get_all_data_from_paths(
data_from_files.append(data) data_from_files.append(data)
all_data.append(data_from_files) all_data.append(data_from_files)
pbar.close()
return all_data, expected_settings return all_data, expected_settings
...@@ -180,7 +180,7 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): ...@@ -180,7 +180,7 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
) )
if "poses" in entity: if "poses" in entity:
raise Exception( raise Exception(
"The poses dataset is depricated. Please use positions and orientations." "The poses dataset is deprecated. Please use positions and orientations."
) )
if "positions" in entity: if "positions" in entity:
assert_validate( assert_validate(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment