From 438877a5dc6760b9a55de7b7771d25fefc1cf768 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Mon, 15 Nov 2021 18:51:15 +0000 Subject: [PATCH] Improved evaluation scripts - plots are accessible from python - labels, titles,... - Added new evaluation metrics (track_distance, individual_speeds) --- src/robofish/evaluate/app.py | 9 ++++-- src/robofish/evaluate/evaluate.py | 32 +++----------------- tests/robofish/evaluate/test_app_evaluate.py | 1 + tests/robofish/io/test_file.py | 21 ++++++++----- 4 files changed, 25 insertions(+), 38 deletions(-) diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 97e3fdd..cd52952 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -28,7 +28,7 @@ def function_dict(): "tracks_distance": base.evaluate_tracks_distance, "evaluate_positionVec": base.evaluate_positionVec, "follow_iid": base.evaluate_follow_iid, - "avg_speeds": base.evaluate_avg_speed, + "individual_speeds": base.evaluate_individual_speeds, "all": base.evaluate_all, } @@ -74,7 +74,7 @@ def evaluate(args=None): help="The paths to files or folders. Multiple paths can be given to compare experiments.", ) parser.add_argument( - "--names", + "--labels", type=str, nargs="+", help="Names, that should be used in the graphs instead of the pahts.", @@ -96,8 +96,11 @@ def evaluate(args=None): raise Exception("When the analysis type is all, a path must be given.") if args.analysis_type in fdict: + if args.labels is None: + args.labels = args.paths + save_path = None if args.save_path is None else Path(args.save_path) - params = {"paths": args.paths, "labels": args.names} + params = {"paths": args.paths, "labels": args.labels} if args.analysis_type == "all": normal_functions = function_dict() normal_functions.pop("all") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 2596b22..60c6e7f 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -92,9 +92,6 @@ def evaluate_speed( right_quantiles.append(np.quantile(path_speeds, 0.999)) speeds.append(path_speeds) - if labels is None: - labels = paths - fig = plt.figure() plt.hist( list(speeds), @@ -147,9 +144,6 @@ def evaluate_turn( right_quantiles.append(np.quantile(path_turns, 0.999)) turns.append(path_turns) - if labels is None: - labels = paths - fig = plt.figure() plt.hist( turns, @@ -212,18 +206,12 @@ def evaluate_orientation( if len(orientations) == 1: ax = [ax] - if labels is None: - labels = paths - for i in range(len(orientations)): orientation = orientations[i] s_1, x_edges, y_edges, bnr = orientation[0] s_2, x_edges, y_edges, bnr = orientation[1] - if labels is None: - ax[i].set_title("Mean orientation in tank") - else: - ax[i].set_title("Mean orientation in tank (%s)" % labels[i]) + ax[i].set_title("Mean orientation in tank (%s)" % labels[i]) ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") @@ -278,9 +266,6 @@ def evaluate_relative_orientation( file.close() orientations.append(path_orientations) - if labels is None: - labels = paths - fig = plt.figure() plt.hist(orientations, bins=40, label=labels, density=True, range=[0, np.pi]) plt.title("Relative orientation") @@ -347,9 +332,6 @@ def evaluate_distanceToWall( worldBoundsX, worldBoundsY = max(worldBoundsX), max(worldBoundsY) - if labels is None: - labels = paths - fig = plt.figure() plt.hist( distances, @@ -405,8 +387,6 @@ def evaluate_tankpositions( fig, ax = plt.subplots(1, len(x_pos), figsize=(8 * len(x_pos), 8)) if len(x_pos) == 1: ax = [ax] - if labels is None: - labels = paths for i in range(len(x_pos)): ax[i].set_title("Tankpositions (%s)" % labels[i]) @@ -587,6 +567,7 @@ def evaluate_tracks( labels: Iterable[str] = None, predicate=None, lw_distances=False, + seed=42, ): """Evaluate the track. @@ -598,7 +579,7 @@ def evaluate_tracks( (example: lambda e: e.category == "fish") """ - random.seed() + random.seed(seed) files_per_path = [robofish.io.read_multiple_files(p) for p in paths] max_files_per_path = max([len(files) for files in files_per_path]) @@ -632,13 +613,13 @@ def evaluate_tracks( return fig -def evaluate_avg_speed( +def evaluate_individual_speeds( paths: Iterable[str], labels: Iterable[str] = None, predicate=None, lw_distances=False, ): - """Evaluate the track. + """Evaluate the average speeds per individual. Args: paths: An array of strings, with files of folders. @@ -652,9 +633,6 @@ def evaluate_avg_speed( fig = plt.figure(figsize=(10, 4)) - if labels is None: - labels = paths - offset = 0 for k, files in enumerate(files_per_path): all_avg_speeds = [] diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index ac3d8c0..5a89d39 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,6 +22,7 @@ def test_app_validate(tmp_path): self.names = None self.analysis_type = analysis_type self.save_path = save_path + self.labels = None for mode in app.function_dict().keys(): if mode == "all": diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index c4d800f..aa30a26 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -39,7 +39,7 @@ def test_missing_attribute(): sf.close() -def test_single_entity_monotonic_step(): +def test_single_entity_frequency_hz(): sf = robofish.io.File(world_size_cm=[100, 100], frequency_hz=(1000 / 40)) test_poses = np.ones(shape=(10, 4)) test_poses[:, 3] = 0 # All Fish pointing right @@ -51,9 +51,10 @@ def test_single_entity_monotonic_step(): def test_single_entity_monotonic_time_points_us(): - sf = robofish.io.File( - world_size_cm=[100, 100], monotonic_time_points_us=np.ones(10) - ) + with pytest.warns(Warning): + sf = robofish.io.File( + world_size_cm=[100, 100], monotonic_time_points_us=np.ones(10) + ) test_poses = np.ones(shape=(10, 4)) test_poses[:, 3] = 0 # All Fish pointing right sf.create_entity("robofish", poses=test_poses) @@ -72,7 +73,10 @@ def test_multiple_entities(): m_points = np.arange(timesteps) - sf = robofish.io.File(world_size_cm=[100, 100], monotonic_time_points_us=m_points) + with pytest.warns(Warning): + sf = robofish.io.File( + world_size_cm=[100, 100], monotonic_time_points_us=m_points + ) returned_entities = sf.create_multiple_entities("fish", poses) returned_names = [entity.name for entity in returned_entities] @@ -115,9 +119,10 @@ def test_multiple_entities(): c_points[:5] = "2020-12-02T10:21:58.100000+00:00" c_points[5:] = robofish.io.now_iso8061() - new_sampling = sf.create_sampling( - monotonic_time_points_us=m_points, calendar_time_points=c_points - ) + with pytest.warns(Warning): + new_sampling = sf.create_sampling( + monotonic_time_points_us=m_points, calendar_time_points=c_points + ) returned_names = sf.create_multiple_entities( "fish", poses, outlines=outlines, sampling=new_sampling -- GitLab