diff --git a/.gitignore b/.gitignore index 043b66a95a35968ee555bd349ade5dddc7a2dc0c..5f762743f6afcdf7443b26c2a4ee3288102fa82b 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ env *.mp4 feature_requests.md +output_graph.png diff --git a/examples/example_readme.py b/examples/example_readme.py index 62df2336f89281a0d9306428839e9aee2921ae4b..197ada6ca723871aaa24aaa87b3aed6469311d28 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -10,11 +10,12 @@ def create_example_file(path): # Create a new robot entity. Positions and orientations are passed # separately in this example. Since the orientations have two columns, # unit vectors are assumed (orientation_x, orientation_y) + circle_rad = np.linspace(0, 2 * np.pi, num=100) f.create_entity( category="robot", name="robot", - positions=np.zeros((100, 2)), - orientations=np.ones((100, 2)) * [0, 1], + positions=np.stack((np.cos(circle_rad), np.sin(circle_rad))).T * 40, + orientations=np.stack((-np.sin(circle_rad), np.cos(circle_rad))).T, ) # Create a new fish entity. diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index ab7b17ecfc6628d4a6549a4886d0efa3a77da6ff..4d603f33ff39fd4b591ce33c8c607a55aa75c979 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -13,6 +13,21 @@ import robofish.evaluate import argparse +def function_dict(): + base = robofish.evaluate.evaluate + return { + "speed": base.evaluate_speed, + "turn": base.evaluate_turn, + "orientation": base.evaluate_orientation, + "relative_orientation": base.evaluate_relativeOrientation, + "distance_to_wall": base.evaluate_distanceToWall, + "tank_positions": base.evaluate_tankpositions, + "trajectories": base.evaluate_trajectories, + "evaluate_positionVec": base.evaluate_positionVec, + "follow_iid": base.evaluate_follow_iid, + } + + def evaluate(args=None): """This function can be called from the commandline to evaluate files. @@ -24,13 +39,7 @@ def evaluate(args=None): (robofish-io-evaluate --help for more info) """ - function_dict = { - "speed": robofish.evaluate.evaluate.evaluate_speed, - "turn": robofish.evaluate.evaluate.evaluate_turn, - "tank_positions": robofish.evaluate.evaluate.evaluate_tankpositions, - "trajectories": robofish.evaluate.evaluate.evaluate_trajectories, - "follow_iid": robofish.evaluate.evaluate.evaluate_follow_iid, - } + fdict = function_dict() parser = argparse.ArgumentParser( description="This function can be called from the commandline to evaluate files.\ @@ -42,7 +51,7 @@ def evaluate(args=None): parser.add_argument( "analysis_type", type=str, - choices=function_dict.keys(), + choices=fdict.keys(), help="The type of analysis.\ speed - A histogram of speeds\ turn - A histogram of angular velocities\ @@ -76,7 +85,7 @@ def evaluate(args=None): if args is None: args = parser.parse_args() - if args.analysis_type in function_dict: - function_dict[args.analysis_type](args.paths, args.names, args.save_path) + if args.analysis_type in fdict: + fdict[args.analysis_type](args.paths, args.names, args.save_path) else: print(f"Evaluation function not found {args.analysis_type}") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 4db49e8766f255d1e0f8ab7bb7946988ca177447..db339f14f3ef5cbaea23da1cf3ccffa09809f804 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -160,7 +160,7 @@ def evaluate_orientation(paths, names=None, save_path=None, predicate=None): if names is None: ax[i].set_title("Mean orientation in tank") else: - ax[i].set_title("Mean orientation in tank (" + names[i] + ")") + ax[i].set_title("Mean orientation in tank (%s)" % names[i]) ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") @@ -317,7 +317,7 @@ def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None): names = paths for i in range(len(x_pos)): - ax[i].set_title("Tankpositions (" + names[i] + ")") + ax[i].set_title("Tankpositions (%s)" % names[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) @@ -367,7 +367,7 @@ def evaluate_trajectories(paths, names=None, save_path=None, predicate=None): sns.scatterplot( x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i] ) - ax[i].set_title("trajectories (" + names[i] + ")") + ax[i].set_title("Trajectories (%s)" % names[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) ax[i].invert_yaxis() @@ -680,4 +680,4 @@ class SeabornFig2Grid: self.fig.canvas.draw() def _resize(self, evt=None): - self.sg.fig.set_size_inches(self.fig.get_size_inches()) \ No newline at end of file + self.sg.fig.set_size_inches(self.fig.get_size_inches()) diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index a1bde527f16a2b28431c6b1148a4ecfec50d2354..0e1ae54889dcbe9a674e19187ff6128d33889c84 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -6,7 +6,9 @@ import numpy as np import logging -def assert_validate(statement: bool, message: str, location: str = None) -> None: +def assert_validate( + statement: bool, message: str, location: str = None, strict_validate=True +) -> None: """ Assert the statement and attach the entity name to the error message. Args: @@ -18,9 +20,12 @@ def assert_validate(statement: bool, message: str, location: str = None) -> None """ if not statement: if location: - raise AssertionError("%s in %s" % (message, location)) - else: + message = "%s in %s" % (message, location) + + if strict_validate: raise AssertionError(message) + else: + logging.warning(message) def assert_validate_type( @@ -329,11 +334,13 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): ori_lengths = np.linalg.norm(orientations, axis=1) + # Check if all orientation lengths are all 1. Different lengths cause warnings. assert_validate( np.isclose(ori_lengths, 1).all(), "The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1" % (min(ori_lengths), max(ori_lengths)), e_name, + strict_validate=False, ) diff --git a/tests/resources/valid.hdf5 b/tests/resources/valid.hdf5 index 0c14ecdad78385832491ff147824f4e259701a05..de8a1dabe0f28ecb0b1600aa01918328858e0f8f 100644 Binary files a/tests/resources/valid.hdf5 and b/tests/resources/valid.hdf5 differ diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 2a7fe486a95c961f41f591ea7c8ab533db6b2ccf..fce024a3c4f7e030b82e301bd9c8871b675beca7 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,7 +22,6 @@ def test_app_validate(): self.names = None self.save_path = graphics_out - # TODO: Get rid of deprecation - with pytest.warns(DeprecationWarning): - app.evaluate(DummyArgs("speed")) + for mode in app.function_dict().keys(): + app.evaluate(DummyArgs(mode)) graphics_out.unlink() diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 31c54bc8382f82aa6376af26f08fcafe257f619a..013d3a51056df5378bb133e2b2f5155c263b2e28 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -187,8 +187,8 @@ def test_load_validate(): def test_get_entity_names(): sf = robofish.io.File(path=valid_file_path) names = sf.entity_names - assert len(names) == 1 - assert names[0] == "fish_1" + assert len(names) == 2 + assert names == ["fish_1", "robot"] def test_File_without_path_or_worldsize():