Skip to content
Snippets Groups Projects
Commit 81dec7eb authored by Andi Gerken's avatar Andi Gerken
Browse files

Merge branch 'develop' into 'master'

Changed the "soft validation error" in form of a warning, Refined evaluation code

Closes #9

See merge request !12
parents 972891f4 fa2757d3
No related branches found
No related tags found
1 merge request!12Changed the "soft validation error" in form of a warning, Refined evaluation code
Pipeline #37761 passed
...@@ -84,13 +84,14 @@ def evaluate(args=None): ...@@ -84,13 +84,14 @@ def evaluate(args=None):
default=None, default=None,
) )
# TODO: ignore fish/ concider_names # TODO: ignore fish/ consider_names
if args is None: if args is None:
args = parser.parse_args() args = parser.parse_args()
if args.analysis_type in fdict: if args.analysis_type in fdict:
params = (args.paths, args.names, Path(args.save_path)) save_path = None if args.save_path is None else Path(args.save_path)
params = (args.paths, args.names, save_path)
if args.analysis_type == "all": if args.analysis_type == "all":
normal_functions = function_dict() normal_functions = function_dict()
normal_functions.pop("all") normal_functions.pop("all")
......
...@@ -363,8 +363,9 @@ def evaluate_distanceToWall( ...@@ -363,8 +363,9 @@ def evaluate_distanceToWall(
range=[0, min(worldBoundsX / 2, worldBoundsY / 2)], range=[0, min(worldBoundsX / 2, worldBoundsY / 2)],
) )
plt.title("Distance to closest wall") plt.title("Distance to closest wall")
plt.xlabel("Distance in cm") plt.xlabel("Distance [cm]")
plt.ylabel("Frequency") plt.ylabel("Frequency")
plt.tight_layout()
plt.ticklabel_format(useOffset=False) plt.ticklabel_format(useOffset=False)
plt.legend() plt.legend()
# plt.tight_layout() # plt.tight_layout()
...@@ -462,7 +463,7 @@ def evaluate_trajectories( ...@@ -462,7 +463,7 @@ def evaluate_trajectories(
) )
world_bounds.append(file.attrs["world_size_cm"]) world_bounds.append(file.attrs["world_size_cm"])
path_poses.append(poses[:, :, :2]) path_poses.append(poses[:, :, :2])
poses = np.concatenate(path_poses, axis = 1) poses = np.concatenate(path_poses, axis=1)
path_pos = { path_pos = {
fish: pd.DataFrame({"x": poses[fish, :, 0], "y": poses[fish, :, 1]}) fish: pd.DataFrame({"x": poses[fish, :, 0], "y": poses[fish, :, 1]})
...@@ -470,7 +471,7 @@ def evaluate_trajectories( ...@@ -470,7 +471,7 @@ def evaluate_trajectories(
} }
combined = pd.concat( combined = pd.concat(
[ [
path_pos[fish].assign(Agent=f"Agent {fish}") path_pos[fish].assign(Agent=f"{file.entity_names[fish]}")
for fish in path_pos.keys() for fish in path_pos.keys()
] ]
) )
...@@ -488,10 +489,11 @@ def evaluate_trajectories( ...@@ -488,10 +489,11 @@ def evaluate_trajectories(
sns.scatterplot( sns.scatterplot(
x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i] x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i]
) )
ax[i].set_title("Trajectories (%s)" % labels[i]) ax[i].set_title(f"Trajectories\n{labels[i]}")
ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2)
ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2)
# ax[i].invert_yaxis() ax[i].set_xlabel("x [cm]")
ax[i].set_ylabel("y [cm]")
ax[i].xaxis.set_ticks_position("top") ax[i].xaxis.set_ticks_position("top")
ax[i].xaxis.set_label_position("top") ax[i].xaxis.set_label_position("top")
ax[i].yaxis.set_ticks_position("left") ax[i].yaxis.set_ticks_position("left")
......
...@@ -9,7 +9,7 @@ import logging ...@@ -9,7 +9,7 @@ import logging
def assert_validate( def assert_validate(
statement: bool, message: str, location: str = None, strict_validate=True statement: bool, message: str, location: str = None, strict_validate=True
) -> None: ) -> None:
""" Assert the statement and attach the entity name to the error message. """Assert the statement and attach the entity name to the error message.
Args: Args:
statement: The statement, which should be tested. statement: The statement, which should be tested.
...@@ -31,7 +31,7 @@ def assert_validate( ...@@ -31,7 +31,7 @@ def assert_validate(
def assert_validate_type( def assert_validate_type(
object, expected_type, object_name: str, location: str = None object, expected_type, object_name: str, location: str = None
) -> None: ) -> None:
""" Assert the statement and attach the entity name to the error message. """Assert the statement and attach the entity name to the error message.
Args: Args:
statement: The statement, which should be tested. statement: The statement, which should be tested.
...@@ -291,7 +291,7 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str): ...@@ -291,7 +291,7 @@ def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
# e_name, # e_name,
# ) # )
except Exception as e: except AssertionError as e:
if strict_validate: if strict_validate:
raise e raise e
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment