From ff4310a21c88031c932a182b2a9ab5daf276b948 Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Mon, 26 Aug 2024 14:40:36 +0200 Subject: [PATCH 1/8] Plotting start for tracks when plotting a track --- src/robofish/io/file.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index a1b82e0..1063960 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -971,6 +971,15 @@ class File(h5py.File): label="End", zorder=5, ) + ax.scatter( + [poses[:, 0, 0]], + [poses[:, 0, 1]], + marker="o", + c="black", + s=ms, + label="Start", + zorder=5, + ) if legend and isinstance(legend, str): ax.legend(legend) elif legend: -- GitLab From cba3451290457e1eaab17abfe47f9759313583b4 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Thu, 20 Mar 2025 11:33:25 +0100 Subject: [PATCH 2/8] Added max_files to evaluation and better sorting when plotting. --- src/robofish/evaluate/app.py | 9 ++++++- src/robofish/evaluate/evaluate.py | 45 ++++++++++++++----------------- src/robofish/io/utils.py | 3 ++- 3 files changed, 30 insertions(+), 27 deletions(-) diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index c7d0a44..5c837bd 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -99,6 +99,12 @@ def evaluate(args: dict = None) -> None: help="Filename for saving resulting graphics.", default=None, ) + parser.add_argument( + "--max_files", + type=int, + default=None, + help="The maximum number of files to be loaded from the given paths.", + ) parser.add_argument( "--add_train_data", action="store_true", @@ -117,6 +123,7 @@ def evaluate(args: dict = None) -> None: if args.analysis_type in fdict: paths = args.paths labels = args.labels + max_files = args.max_files print("starting", paths, labels) if labels is None: @@ -141,7 +148,7 @@ def evaluate(args: dict = None) -> None: print("starting", paths, labels) save_path = None if args.save_path is None else Path(args.save_path) - params = {"paths": paths, "labels": labels} + params = {"paths": paths, "labels": labels, "max_files": max_files} if args.analysis_type == "all": normal_functions = function_dict() normal_functions.pop("all") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index bede8f7..1f4d2af 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -505,11 +505,6 @@ def evaluate_quiver( poses_from_paths = np.array(poses_from_paths) - # if speeds_turns_from_paths.dtype == np.object: - # print( - # "This will probably fail because the type of speeds_turns_from_path is object" - # ) - # print(speeds_turns_from_paths) try: print(speeds_turns_from_paths) speeds_turns_from_paths = np.stack( @@ -652,8 +647,10 @@ def evaluate_social_vector( try: from fish_models.models.pascals_lstms.attribution import SocialVectors - except ImportError: - warnings.warn("Please install the fish_models package to use this function.") + except ImportError as e: + warnings.warn( + "Please install the fish_models package to use this function.\n", e + ) return if poses_from_paths is None: @@ -769,7 +766,7 @@ def evaluate_follow_iid( follow_iid_data = pd.DataFrame( { - "IID [cm]": iid[i][mask], + f"IID [cm] {labels[i]}": iid[i][mask], "Follow": follow[i][mask], }, dtype=np.float32, @@ -777,7 +774,7 @@ def evaluate_follow_iid( plt.rcParams["lines.markersize"] = 1 grid = sns.jointplot( - x="IID [cm]", + x=f"IID [cm] {labels[i]}", y="Follow", data=follow_iid_data, kind="hist", @@ -788,25 +785,19 @@ def evaluate_follow_iid( joint_kws={"bins": 30}, marginal_kws=dict(bins=30), ) - # grid.fig.set_figwidth(9) - # grid.fig.set_figheight(6) - # grid.fig.subplots_adjust(top=0.9) grids.append(grid) # This is neccessary because joint plot does not receive an ax object. # It creates issues with too many plots though # Created an issue - fig = plt.figure(figsize=(6 * len(grids), 6)) + fig = plt.figure(figsize=(12, 5)) - fig.suptitle( - f"follow/iid: from left to right:\n{', '.join([str(l) for l in labels])}", - fontsize=12, - ) + gs = gridspec.GridSpec(1, len(grids)) - - for i in range(len(grids)): - SeabornFig2Grid(grids[i], fig, gs[i]) - + + for i in range(len(grids)): + SeabornFig2Grid(grids[i], fig, gs[i]) + return fig @@ -850,6 +841,7 @@ def evaluate_tracks( seed: int = 42, max_timesteps: int = None, verbose: bool = False, + max_files: int = None ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -871,9 +863,11 @@ def evaluate_tracks( paths = [Path(p) for p in paths] random.seed(seed) - files_per_path = utils.get_all_files_from_paths(paths) + files_per_path = utils.get_all_files_from_paths(paths, max_files=max_files) max_files_per_path = max([len(files) for files in files_per_path]) + min_files_per_path = min([len(files) for files in files_per_path]) + rows, cols = len(files_per_path), min(6, max_files_per_path) multirow = False @@ -886,12 +880,13 @@ def evaluate_tracks( initial_poses_info_available = None all_initial_info = [] + selected_tracks = np.random.choice(range(min_files_per_path), cols, replace=False) + # Iterate all paths for k, files_in_path in enumerate(files_per_path): - random.shuffle(files_in_path) # Iterate all files - for i, file_path in enumerate(files_in_path): + for i, file_path in enumerate(np.array(files_in_path)[selected_tracks]): if multirow: if i >= cols * rows: break @@ -1340,7 +1335,7 @@ def show_values( color = (0.0, 0.0, 0.0) else: color = (1.0, 1.0, 1.0) - if value == "--": + if np.all(value == "--"): ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw) diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 1a64666..940879b 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -48,7 +48,8 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None): files_path = [] for ext in ("hdf", "hdf5", "h5", "he5"): files_path += list(path.rglob(f"*.{ext}")) - files.append(files_path[:max_files]) + + files.append(sorted(files_path)[:max_files]) else: files.append([path]) return files -- GitLab From 184f3c7bc92ef4c11cc6f099ea9b99081216d701 Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Fri, 21 Mar 2025 15:38:51 +0100 Subject: [PATCH 3/8] Caching entity objects in file for performance boost. --- src/robofish/io/entity.py | 2 +- src/robofish/io/file.py | 29 ++++++++++++++++++++--------- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/src/robofish/io/entity.py b/src/robofish/io/entity.py index c2e7a67..7ff5b34 100644 --- a/src/robofish/io/entity.py +++ b/src/robofish/io/entity.py @@ -236,7 +236,7 @@ class Entity(h5py.Group): "merge to the master branch of fish_models if nothing helps, contact Andi.\n" "Don't ignore this warning, it's a serious issue.", ) - def speed_turn(self): + def speed_turn(self) -> np.ndarray: """Get the speed, turn and from the positions. The vectors pointing from each position to the next are computed. diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 1063960..ca73444 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -29,7 +29,7 @@ import warnings from pathlib import Path from subprocess import run from textwrap import wrap -from typing import Dict, Iterable, List, Optional, Tuple, Union +from typing import Iterable, List, Optional, Tuple, Union import deprecation import h5py @@ -61,7 +61,7 @@ class File(h5py.File): def __init__( self, - path: Optinal[Union[str, Path]] = None, + path: Optional[Union[str, Path]] = None, mode: str = "r", *, # PEP 3102 world_size_cm: Optional[List[int]] = None, @@ -150,6 +150,8 @@ class File(h5py.File): self.validate_when_saving = validate_when_saving self.calculate_data_on_close = calculate_data_on_close + self._entities = None + if open_copy: assert ( path is not None @@ -502,6 +504,8 @@ class File(h5py.File): sampling=sampling, ) + self._entities = None # Reset the entities cache + return entity def create_multiple_entities( @@ -584,10 +588,12 @@ class File(h5py.File): @property def entities(self): - return [ - robofish.io.Entity.from_h5py_group(self["entities"][name]) - for name in self.entity_names - ] + if self._entities is None: + self._entities = [ + robofish.io.Entity.from_h5py_group(self["entities"][name]) + for name in self.entity_names + ] + return self._entities @property def entity_positions(self): @@ -702,7 +708,11 @@ class File(h5py.File): else: properties = [entity_property.__get__(entity) for entity in entities] - max_timesteps = max([0] + [p.shape[0] for p in properties]) + n_timesteps = [p.shape[0] for p in properties] + max_timesteps = max(n_timesteps) + + if np.all(np.equal(n_timesteps, max_timesteps)): + return np.array(properties) property_array = np.empty( (len(entities), max_timesteps, properties[0].shape[1]) @@ -1068,13 +1078,14 @@ class File(h5py.File): categories = [entity.attrs.get("category", None) for entity in self.entities] n_fish = len([c for c in categories if c == "organism"]) - lines = [ plt.plot( [], [], lw=linewidth, - color=custom_colors[i%len(custom_colors)-1] if custom_colors else None, + color=custom_colors[i % len(custom_colors) - 1] + if custom_colors + else None, zorder=0, )[0] for i in range(n_entities) -- GitLab From 10bc54ff482827f76574bfa1a7cf568b5804c18c Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Fri, 21 Mar 2025 16:27:39 +0100 Subject: [PATCH 4/8] Added rendering zone sizes from attributes. --- src/robofish/io/file.py | 62 ++++++++++++++--------- tests/resources/valid_couzin_params.hdf5 | Bin 0 -> 33648 bytes 2 files changed, 37 insertions(+), 25 deletions(-) create mode 100644 tests/resources/valid_couzin_params.hdf5 diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index ca73444..c45bb74 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -29,7 +29,7 @@ import warnings from pathlib import Path from subprocess import run from textwrap import wrap -from typing import Iterable, List, Optional, Tuple, Union +from typing import Dict, Iterable, List, Optional, Tuple, Union import deprecation import h5py @@ -1097,23 +1097,25 @@ class File(h5py.File): if categories[i] == "organism" ] - zone_sizes = [ - ( - get_zone_sizes(self.attrs.get("guppy_model_rollout", "")) - if render_zones - else {} - ) - for _ in range(n_fish) - ] - zones = [ - [ - plt.Circle( - (0, 0), zone_size, color=fish_colors[i], alpha=0.2, fill=False - ) - for zone_size in zone_sizes_fish.values() - ] - for i, zone_sizes_fish in enumerate(zone_sizes) - ] + zones = [] + if render_zones: + for ei, e in enumerate(self.entities): + zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", "")) + zone_sizes_attrs = get_zone_sizes_from_attrs(e) + + # Check that there are no zone sizes in the model and in the attributes + assert zone_sizes_str == {} or zone_sizes_attrs == {}, "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." + zone_sizes = zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str + + fov = zone_sizes.get("fov", np.pi*2) + fov = np.rad2deg(fov) + zone_sizes.pop("fov", None) + + entity_zones = [] + for zone_size in zone_sizes.values(): + entity_zones.append(matplotlib.patches.Arc((0,0), zone_size, zone_size, angle=0, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.3, fill=False)) + zones.append(entity_zones) + zones_flat = [] for zones_fish in zones: for zone in zones_fish: @@ -1388,12 +1390,12 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): - if categories[i_entity] == "organism": - for zone in zones[i_entity]: - zone.center = ( - this_pose[i_entity, 0], - this_pose[i_entity, 1], - ) + for zone in zones[i_entity]: + zone.center = ( + this_pose[i_entity, 0], + this_pose[i_entity, 1], + ) + zone.angle = this_pose[i_entity, 2] * 180 / np.pi if render_swarm_center: swarm_center[0].set_offsets(swarm_center_position[file_frame]) @@ -1454,7 +1456,16 @@ class File(h5py.File): plt.show() -def get_zone_sizes(model: str): +def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]: + if "model" not in e: + return {} + else: + possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"] + model_params = e["model"]["parameters"].attrs + return {p: model_params[p] for p in possible_params if p in model_params} + + +def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]: zone_sizes = {} for zone in [ "zor", @@ -1462,6 +1473,7 @@ def get_zone_sizes(model: str): "zoo", "preferred_d", "neighbor_radius", + "fov" ]: match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) if match: diff --git a/tests/resources/valid_couzin_params.hdf5 b/tests/resources/valid_couzin_params.hdf5 new file mode 100644 index 0000000000000000000000000000000000000000..e8c7ef32b4b929f0db806dadbfd266ab1eaca66b GIT binary patch literal 33648 zcmeD5aB<`1lHy_j0S*oZ76t(j3y%LoK|wP_2+I8r;W02IKpBisx&unDV1h6h8Mqig zauN_Og8<Zg1!jnV21t^DfgvQw)s=yPkpX5tjE1OUU|@h60Hxr<ql}Re0v@i80U)17 zfCvT#1`Q~E0-DaCT!z%VlFX9K)M6OFI5D>%Co?Y{CIC%t3<fX-1ZY@#eGvswm;s>} zco+g0SQtPlmXAS%L4u(?zbGdqzBscgH9k3)fq_9#K^`p6z@We&&cPsF57o!W2yp?- z*)W=c;e-<i^DuaT^)WK=Gw?7-Fr?)d<tCQIm!%dJXXfWIFbHrkfaO8vf{esv9>f*~ z284Os3?d91F!M@_azMHn7$g}P!8sC?>mhnLz!DHP0|PkFAwuAM3=0nj1+X#(Sa>ip zfFL}ZG6;fAWnd@(1rh@T!vr>ndP4c!K^dYBY9)j-N{*NiK+or}@_qv}{lEgEC_gE` zgn@w}EweZy-Vl@zAu<jSie!?Jfd!ntVfMk&`2l%|Q4J7^0o{F|d;^jSU<Z*1Y-FmC zQ3Wm^AYvX6ifj^>`#hlTtALu!$iRlx9~~+X1u%a=cnl2g&i+2&ybkpUn9Tsq_Y4dS z86_nJ#a8<I>6s;ZnYjgeX{EYJsYN-Nd3q_S`bnAj;CRSPF3w3z(g(3Y@u6Q*l$e|y z4=K9zb@hu<b5awFQ;YQt^$cL;9V^rUuyBGl2N<-W;vpHC#R{3l3W*BEnYjfysS2qT zppraQp**uBLm@XYB~_ucKp`cuBvB8n16s^5z{~}?1(ZHue3(4QeIO06<Rk^+FfcIq zRw@+bmn4>C=A|nX7bK=<q?VLqD!^1TFu<Z0rY||av??=?fdN)dp^I}u%|#c7G_)8P z(8alNsON!-!|cf~N>9woEY9VHildvu2Ng$G&kq$xR}YPLxIbX#3*t}@6=t{qa{x?W zlpYO%(GVC7fzc2c4S~@R7!85Z5Eu=C(GVC7fzc2c)FBY;=kLPEzy$CAc|beLuyFwN z@jKY~2nV!73{&SJ3Gt=@R2;$^IV7QRO$IqgSit(_5Z)*`QbGWdJ{WM%Z;h0I81*OC z5I`SShLsa5WKosF%3JWf45}1aJV^P^2rvIvK+A8KJsHvvV>{5?0TCS~VIcsU??MkR z9tH>Sd>tbLFJ%6$D7COOH7~g`9y~pOd!7(xFvAH)2op9^3*o^?n7!!rFT!39u)Qg% zX^EvdB?RaDz~kl6#0MFxht>En8&PS*Jg6Ypp45th)S}GX)Vz}T;?$DT0+?z-^Niqe zf2h5Xxd12!Y5=7C#JxU1LJgu2RTo5TlpOvc04qN?K)ni^m$q<)3P7oX{Nl`#%=|pq zdVu_*%+x&4ybn|eR_+nIJ^)gFGC^m%85jzn8e!(lfbw@3K^y=}pBK=|*%i?F^b<-D zbt8?0m1F4PfLOo42`*of6H8Ll^NT8B;X|lgK`zfgT^rcE?hR;sz}Az%=1swqr{H;8 z$O;oiCP+-CKrMuni3|+rd}am-uvT2_K3Ewz!17E?5Cb`&3!0c26VxF_!&XE=ctX$` z8D#suUiSnB_(}_iJgrGcx@KX3&0A}@;I|PHS_}*l&iK{AYLf(g{OVx&ghLg-y1_52 z-~@7G=UREEhFB$sifwXE7daO>oRyY!x@r5#LGQkllh_^w$5%`ePW_wR9Y1%9IO*Qa zb-Z6L;B>a5)A9ZVZYSHRvmN)(WOMRiT<duMI)f9-?!AsNN?#o{H=cF$p7Gq#?eZPR zI{BN92L)d{J~(#L@pbGEN9TDv9XUdoocy<~a6I^h-Kj`nx?>0juhW@#EspJ7f=;#> zDUO@UM4j^4?HtdXkZ{`fNX#+4R@#Zv^`677iLy>*_j?@{Fv~kln)lg$_C<N8HRVeC zZ~T&Xa(L)<K&M{bX_w}<17h)VPHn=>2M^4aaT4m*I{3>>%IRWI;6WQ>ai^0`<p&o{ z7Iyj|*>~_oG`|xA!-9k96S$m|95)^mFk^M<IeqZp`6d4x4f`%0WW4^-@p8q3gD?1> zI11FfJGedes$)gx?}O9MA9Xx^lKD_~=~hPucg{m42}>L|oADiby>OD_*~vnOtbH0B z&qs<Kn))Ql@oJ0YApw6A$Jv51ht~e*ajgF<d+3?%DTfX>`9tr*iyabQ%O5HzT4WEU z_d~@GK-C?9>N^NE=OEPFgHU@8LhU^Wb;m)dyADF#c@XOEgHV4Qg!=0s)Sm~T{yqc^ zheObCIRp)-L(p(L1P#YS&~QBj4d+A9a6bf%heObKIRuTTL(q6T1dYc-(0Dxrjpswq zcs~S92Zx~P;t({Q9D=5sgV1zz5Sp$ILetqnXu3NHO@{}e>GB{nogRdy+k?<_d=Q$h z4?@%V0cg5E0L=#n7#I|s>S_&EFfcGQFfcUOr#)%h%U4rxU$T&K-+>Rc_9CYp_pOVs zwy&7hu&?HHx&6&U^Y`5`FShUY+PANDWv;yl-<^F%YU%dJ-~HG(V@te!K|0s|l$db) z3n#?)8>;!)@9tIEUn}im&$`lJ{~|X_dlNpJ{iiqS+IMbr+y5+E(Vn*>VE>#t5&PNg zk^3Efu-U&|mbky=>_@vLcQf|8Z@gj`ppw7;)Xr^ol_e$nKeNuTTY92m|Ea7(yBL|8 z{Z}p++8O86?N|Q)$oA;g`u&biooo+$uiqc`vvJP@!}|SaeHZQhkW#z<_b$PGXC_tc zzrG`IU(CL;{f_w^`zF6G++R|<d|#$~&VKguNB7-~Pu;&c=JCGkn`8GgFfi=*Q4Za| z^(^mxp6%ZIZ8@d)pQ~`(|7D5#euZMQ{dXoC?@!pMwSW0ZyZx)}<@O)9^4xz-OK^Y1 zouK`(4b1z`E{Wd1dd{nTJGzqhA8|XoPqZy-zl-epeVh6V_B+^5*cZC6bidfl#C=D$ zR_>QOt+?;ZiJJZU)*RYvf4y$M+E>QCjQ8t7v@M9<x8{&7NL=o;q8&(`*vv#bkUj_d z33ecJTx8eVfy_PPcGeDL&+0j^>_GO$HZa?R+;L4y&>rNj)%J4sAa^Eg)UpS;TcOy@ z9^{X66^`~GfAMVhwg>ssM>*6U<nQa7W9>oVa5FyD9uzK_@;UaPaGLzO&>j?SG5gBw zLE(62Qk6X@TtB4L+JnORfnmKpDBKTxueS%q!_lqv_Mmt%&Z)Bp#Z!z-jXfydmY%4v z2gPG$Nr^ouUISF}?LqOp<ZgyNDBj;LOSA{2gW2to_MmjZ+Yw+7N++Eg-Rwc>#)Qwt z9+Zw)R~p!Z($(%>6?;%RyKq9>9+d71(z)zG>G1fwA9kR0DZ+Qh4wO#2z4qCG((TPd z^X)+CxMEs^9VlIkoOZMWrSp=7jCP=OpZ28D7L*T+40}~U`2tjTjN+ji0t45sY|w&) z+TdzeGBR+2+rhAQAgn!ZU<`3BY<(i3{{8_Ih<aE%5yBfKM@k5w_xE7~O&Q?UHfY_u z1$5kvn}LUcgP|b5IJG!FBe57dFaTOd4@o<a6~mx@7AWjQK*T>#cOL(E+YGdRJuDq3 zfNC!W28I_Pa~MGGfQ|n_#^H(!Qj<a36A1NN!Rt~X4ur221sOST<w=4$$bJTf!Bw6> z`l+n&erkjx#Fe;6SU=SR679H@!F$^;iXcj0=|ccIaAZ&l5r>sy9A*$cE)p7*3>%>3 z94>VOB|dQJs=*54=fRb(Aa0@t>2ZjG0fb!+F))Cz!yyI+5VkqQzyQJ)hZq<@*yIoc z0|*-&VqgGaokI)^AgpnSfdPb74lyu*u)-k*1`w7x#J~W;5{DQVKv?7u0|N*P9AaPq zVV*+_3?R&Lh=BowSq?EUfH1=$1_luRagc!lgg+c)U;yD42N@VZ_`yL21`xh+kbwb& zFC1iG0O1n{85ltLz(EEE5Z-Z+fdPa!9Asbs;S~oN7(jTzK~Nu;fnml$2%d0|fdPa& z4l*!+aKk|c1`sYd$iM)?2?rS%Ksewa0|N*<9AsbsVS|GV3?Qs<kbwb&1r9PWfH1>B z1_ls*aDagUgby5GU;yC-2N)PYxZwZ;0|+}DU|;}Yh64-?Al$H@fdPbJaoS)HZC`^* z0tW^L2PkfEU|;~@1r7`hAbh}qfdPacI504PFaxNBab#c+fMNwl1_lr|aAaTrVFyPB z1`rN#WMBZ{1V;u25H4_JU;yC;M+OEE?r>ya0O1La3=AMV!;ygjgcmq6Fo5t1M+OEE z-r&f<0Kz*Q85ltLfFlC~2!rN@K=^_q0|N-(aAaTr;RlWk3?Tf%k%0k(KR7Zlfbb7T z1_lsj0M(FA3=Aw#%;Chq0Kz;@3=AMF;KaZH!Xi!#3?MAw#J~W;GENK(APgGM0$~*= z1_lt;aAIHpVI3z11`sxIVqgGa6DI}+5Vmk)U;tqoCk6%(c5q@~0AUv=1_lrY#RCJw zy1IJ%z!?p;*1v1*qr(~Objz#lA3w0S6Z~3kKVe0kU13GBeU<q<yQ{2u_Lr*m*zI4G zZolvUO}j;*3HI45zS%ucjj;cu&1t_^+TUKHQp`Tu!PVYEPuYIMTr2x(JAE()u|Z-W zIglEVT96))UXU3evp{Bo%mvI}G<$Zn7uKyCrK3FJ1A8$oUbxf$g4^mNOmp!5_t zqhT+IA05uH4<!Hif&D&^`UxxQ_JQ<Qna|q?GVfB=o_!$m_x-=Q4`g5Vif{Ws_J7jm z+z)b(M5Wk%kozq3l=p+&TWzPmALRZGbFKD+{1ffqx*z11z0&^sL4J9l8nGYbmqnoo z`$2x$zbJh_$S+q}^Y(-MQdm*EALJLoujTtee$g$j-VgGN_3zsKAiu1utKSdu3n-7I zr|(<}<?n^cgZuze5Ap*@KgbUt^FV$8nGf;<$UcxCK=$v0x(DP3ko!P>0J#_B2ax;0 z7~~hI|G;WMeu4TQYzD|L(C`7<0rCqp{J?Gi`2`xjAU}Y@8{`L2`0qb;G8hyeAigcg zFCagH`~vbb$nT&q0O<#X2`FqpVFWTC6lS2X1BD^Teo&Z#`~nJNko!Pk4hnlv9Dv*p ziW5-WfZ_<`7f_sm;tmvtAisd(6co3hI0pFz6z8D02c-d!UqERBlr}(V1mqV`ngOL9 zP#OZIC0J=SN)OEtK+MBSfcqh!9cA&wC5d?{iA5>#IjMQ+B^jU{!?^k*u>E2@4CM7^ z&OtZQgXYmdAqeV!gZ2R@lQ@qCZoh&F22kgS0kQK9WGZO?C1l?rxO0zZ{tVWRhRvVB z_=x#4d+5xU1A_@SgUNr$VgLr*^JgB=DJtCU$IsF(N%WpSBh;Qwh6Eb}6O=~hlOG-# z-Vh@VU|apbT;%!0fd~(<BN!M!{cybV@5zZd$)!1oC8;U#$W!z2MTw9UgCw7r44%l3 zFD^(;O(~8qDJ{w?X2{J?Nd;|fhovh};|^vwDKw$@gSZBi?jfe4^U06D8G#TJN7FMX z9GDq*gn(!UhLM?`Vd)rJ$U%Z)lpGC#!5ab#_KR)+je<7p7d-%?6Xq+N0MQHfi(UZH z4f@hoK(xbrg&Sa+U+oTvUa(*E0f=rWmv{uG^`)PH=!EC8&%pG2g%=>&Azb+tnC4e| z1E!B_yaUk-_KSW1(GAk#AHj6F#3wL)S@JWO)|dVQrn_ang6M?jvfse8v;223JzwDm zh<5m`_!CTrEB^x1n^k^;X@0doAX*_?{V$k4uJI2{D{K7+(GB}W89>!dL+&&PEiKLn z=AWDnq01$hz<kx25c;wtGnn5v3qtElvw-<`XG7?28CDQK!E!EyelE)f=1-jmp`GQ~ z!TgW&A@qC&4lv($A%y;|$O-1JTm+%RmASxt)+G>nvkEtupSTo4^Q-ZI`FodvX@_ie zUJ&0was`AwuE7W9m#&1+%3Azj{>4>bdcl5C0T6$ISQv!Pod%(gT!YZk;(}oLqHqX( zayo=oxB;dc$|ZzA@(t%AAhhaC2wig%LSL2?2Fq(jLFmR=5c<X~2(2$I0+w%!hR}Cs zLuivb5V~7N6fFNJ20~lTh0qi3Lg?qRVqkf@I0!v;9)x~#4@@UG%Zr2L6K2Ij=#TRu zw8sMoJzqfrEdMP5Li;X+&`TbI=?3Ya;PeU09}0pG!Ra&snqLyWZ3Cy*15m!h=Spz8 zh4CLaX+qL3RQ`eS5ip+t#@F-$r)!1-(EK-d7YbiNRur7xVd@=xBO!dKdWV;75I)qr z1Bd29_)vK`A1V)550ytZ4{je+z5(W*eGvCOKyxqLKT!2B|HAwO3lA6{7G4Zc^I+i# zlZS;jjL*Qp01khc`i7FXZ$bGMCf~4T&ubL^z5JKR`~|U`&%u1CeGSvP!Q~TNKkv_n zV0ozf8Fs9=2j)Z7D?~fo1oNTt4b}TDqws?Q&!F%HIS(T98M5{*L+69mjKJ*|<UEMN z4+=bk!mr+U8JQ1q4>BL*UNC<TBz_ure?A2BVeXyQ{REj03J+v{!@c~MDEu{hUZe0! z-o8cVgW?034~j2jJ}5qs`Jnhl=7Z7$G9Q#)Q2542zM$|my}lyzLFp09hx!+kUcr2* ze?jRPh5xec8!{i1AMo)(`2$%V-8^*r;O>Fi56U0N`a$^vh3}+^;y?88fQQ!(NPMD) zCwh2e#0Oe@f$|5k{pj(Hksi>}3wnBjr#Gm7(bFR+f1ucho}SUtJ1Bo3s|WFs^GgF- zeriC^Z)o`u<Q}koX!wEL3+6-J3-S*NAFVt9g$If}T6qErPZW8y@+bk?4mt2a719oQ zrU`C`EV!);X@^`i1h+#L95aQqL-tsL+aU*T$U)j4mluF(hjU8c_6NhsC1Co%VKqqm zW8Vre?XXh|-2PzLvIa~aSf>YWe<ZBj0Hzg|7=ha#59VzF(+g&rLE0Zvc7SPxek*YM z<3Q&gFx}8%2X22j)E@xT3RO<v_Q!*gBVc+#z8kpx;gEF#Of#f<f!iMk;?IEThA2Ns z`y=!Mm}UqF0=GX7cwGV04X$C}_J@MQ4KV${Dhks6Fuen&9VGO@?TrS}0x-=Wyb(-0 z2z&z53wVve?F|O*5-^>>u@y`oVEqQB6`0K-?G1(sF#X`~4lwQT`zM%Q@WTq+-eCAv z4W=DF?*Y>XKKupK4Dak9?Ty#<VEVv|17KR=86%{f@z@DWKX}j#rXB7b0n-a^vx3_h z3^(1tbi%cEFn!?i2{5g2krUj`XgKEurXQT{2Gb5F&w%L#$9Tc*j0cDPz_i1`elWdY z-vuzuut!i5)IM=wj$Z*zcQ^C)fanB?@M9qQ!QrxVAezC#8(c3P;Elcx;xBldbsI!C z?5w&Eq7S$q1=n*8MrR*`_zkjGpMq(QyWsjU;qw!4z3Fi84Y)o{IQa!!&nj&B^B$yb zfj<+t-J{^j25#>(SaE^dISdAT;Py;|st}~zA}t1Pe<TP<f!pT~SmnU&b6CD$*x(IL zpA2F4;QodKj2{8vA6Snn4_6PBe*lw*@)OYX!_9-5kFI_Jx_&hN12pr{?MLUMn-8}i z>OMF>1meF0H2WIR>}x=iNB1AP{qXRDnvX7z?moDD48%QfKGePFe0X?4&4>F3sveyW zcMnt^-8^)8bpI_t3lF$@sQu{Xq5BUWzEJ-ifSDHn@vj3cd=nsixPGWSoDX%s0!%-Y zkM3Ud_=ATRR6V+WboZn4(d|bMU%3CE?uFYA^$%P<R32_WR309lP<eFw;o$>yKRo=P zeDw4QHxDWgw-3rkmxsF-DvutX=<3nkhb|BIFVsA^{ZRiAiVyVghS=xEy%DYaNC=0d z7l=Fq!vQECt{=*W$vZ&Z1NSdf9u|KMP(Dl^#)ru>K=s4@2Xzm6c){(5s)yOn05u<- z4=>*$A^w5e2Q?oazEF8|^WotORgdl-xPEl^K>2X<p!(tFLFM7@1@U3~%h1oKM$C)x zfcKX{&+*MjEJy{9un}6P2Ri5RMG{ykZ0ZZb86`(UU^E0qLtr!nMnhmU1O_|=aGi4v zs=q;wD@si+Nz6;nfz<jS5eU8jJ2)A_9VJIYU^E0qLtr!nMnhmU1V%$(Gz3ONU^E0q zLtr!nMnhmU1O{;kFfuTM$LC@DwPE}9S3vi1!|t&pbne6n=tW!5tuTYw$D>Z8MF_yo zoq(Mi2un`^&~yZ|vmmi3F*mg&wFr7%0BjvUXaNRH4JkBi{tEs40nq+3kU5MDtl;^p zyu@7a5Eh|%EAZwZXd)H_@7H3-nm+_Um)JozP=l<2-qSjIjx{(iMhT1%K+hl0B*stx zEjM8rVEKVo=U8);L2QEAV*up`pt%F0YLtY90IXhAfZq2Eiys~a$bmPYdniHY=~U(C zgS8Oew~DSW5wAY*`8(iqbI|oc4r0M&ANZVCu*uXVVEb=j?xK}F)OGc!X`>-98Umvs iFd71*Aut*Ol!O4R--X_OL~jqJ<wF{nxY{SMa~=Q<&+Y*L literal 0 HcmV?d00001 -- GitLab From f05e263642d6a79e8af0194433f2377382cb792b Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Fri, 21 Mar 2025 16:45:49 +0100 Subject: [PATCH 5/8] Changed arcs to wedges for better display of zones --- src/robofish/io/file.py | 17 +++++++++++++---- tests/resources/valid_couzin_params.hdf5 | Bin 33648 -> 33648 bytes 2 files changed, 13 insertions(+), 4 deletions(-) diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index c45bb74..cc267cb 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -1099,6 +1099,7 @@ class File(h5py.File): zones = [] if render_zones: + fovs = [] for ei, e in enumerate(self.entities): zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", "")) zone_sizes_attrs = get_zone_sizes_from_attrs(e) @@ -1110,10 +1111,14 @@ class File(h5py.File): fov = zone_sizes.get("fov", np.pi*2) fov = np.rad2deg(fov) zone_sizes.pop("fov", None) + fovs.append(fov) entity_zones = [] for zone_size in zone_sizes.values(): - entity_zones.append(matplotlib.patches.Arc((0,0), zone_size, zone_size, angle=0, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.3, fill=False)) + if fov >= 360: + entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.3, fill=False)) + else: + entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.3, fill=False)) zones.append(entity_zones) zones_flat = [] @@ -1391,11 +1396,15 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): for zone in zones[i_entity]: - zone.center = ( + + zone.set_center(( this_pose[i_entity, 0], this_pose[i_entity, 1], - ) - zone.angle = this_pose[i_entity, 2] * 180 / np.pi + )) + if fovs[i_entity] < 360: + ori_deg = np.rad2deg(this_pose[i_entity, 2]) + zone.theta1 = ori_deg - fovs[i_entity] / 2 + zone.theta2 = ori_deg + fovs[i_entity] / 2 if render_swarm_center: swarm_center[0].set_offsets(swarm_center_position[file_frame]) diff --git a/tests/resources/valid_couzin_params.hdf5 b/tests/resources/valid_couzin_params.hdf5 index e8c7ef32b4b929f0db806dadbfd266ab1eaca66b..506cdd1ae9df5a0438b2bcec9b47ffc016037515 100644 GIT binary patch delta 63 zcmey+#`K|$X@db5qtN7qoZ^halN~w58AZTsQ4o8w0T%~T0q<r<t~B0F5+)jqBAZPe O&YLodZe}gZ<^=#Qv=F5L delta 39 vcmey+#`K|$X@ddR<Oy6HOc!`J8zvv)-6Ub6!6>lV)Zx4-qu^%NvTR-eA}|dk -- GitLab From 7e856352e2f6ba0d549d49b69c44111139ebdab9 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Fri, 21 Mar 2025 17:58:13 +0100 Subject: [PATCH 6/8] Added smoothing for fov orientation change to make it less stressful to watch. --- .../load_model_params_from_model_options.py | 44 +++++++++++++++++++ src/robofish/io/app.py | 1 + src/robofish/io/file.py | 22 +++++++--- 3 files changed, 61 insertions(+), 6 deletions(-) create mode 100644 src/conversion_scripts/load_model_params_from_model_options.py diff --git a/src/conversion_scripts/load_model_params_from_model_options.py b/src/conversion_scripts/load_model_params_from_model_options.py new file mode 100644 index 0000000..fb17140 --- /dev/null +++ b/src/conversion_scripts/load_model_params_from_model_options.py @@ -0,0 +1,44 @@ +from pathlib import Path +import argparse +import robofish.io +from tqdm import tqdm + +def main(file_path): + files = list(file_path.glob("*.hdf5")) if file_path.is_dir() else [file_path] + + for f in tqdm(files): + with robofish.io.File(f, "r+") as iof: + model_options = iof.attrs["model_options"] + model_options = eval(model_options)["options"] + + neccessary_options = ["zor", "zoo", "zoa", "fov", "additive_zone_sizes"] + + # extracted_options + e_options = {no: model_options[no] for no in neccessary_options} + + if e_options["additive_zone_sizes"]: + e_options["zoo"] += e_options["zor"] + e_options["zoa"] += e_options["zoo"] + + for e in iof.entities: + e.attrs["category"] = "organism" + + if "model" in e: + del e["model"] + + g = e.create_group("model") + g.attrs["name"] = "couzin" + + p = g.create_group("parameters") + p.attrs["zoo"] = e_options["zoo"] + p.attrs["zoa"] = e_options["zoa"] + p.attrs["zor"] = e_options["zor"] + p.attrs["fov"] = e_options["fov"] + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Convert csv file from the socoro experiments to RoboFish track format.") + parser.add_argument("file_path", type=str, help="Path to the csv file.") + args = parser.parse_args() + main(Path(args.file_path)) \ No newline at end of file diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 9c35211..4aec636 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -267,6 +267,7 @@ def render(args: argparse.Namespace = None) -> None: "render_swarm_center": False, "highlight_switches": False, "figsize": 10, + "fov_smoothing_factor": 0.8 } for key, value in default_options.items(): diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index cc267cb..8af3217 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -1026,6 +1026,7 @@ class File(h5py.File): custom_colors: bool = None, dpi: int = 200, figsize: int = 10, + fov_smoothing_factor: float = 0.8 ) -> None: """Render a video of the file. @@ -1116,9 +1117,9 @@ class File(h5py.File): entity_zones = [] for zone_size in zone_sizes.values(): if fov >= 360: - entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.3, fill=False)) + entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.2, fill=False)) else: - entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.3, fill=False)) + entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.2, fill=False)) zones.append(entity_zones) zones_flat = [] @@ -1319,6 +1320,8 @@ class File(h5py.File): pbar = tqdm(range(n_frames)) if video_path is not None else None + self.fov_orientations = [None] * n_entities + def update(frame): output_list = [] @@ -1395,15 +1398,22 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): - for zone in zones[i_entity]: - + new_ori = this_pose[i_entity, 2] + old_ori = self.fov_orientations[i_entity] if self.fov_orientations[i_entity] is not None else new_ori + + # Mix in complex space to handle the wrap around + smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (1-fov_smoothing_factor) * np.exp(1j * new_ori) + self.fov_orientations[i_entity] = np.angle(smoothed_ori) + + ori_deg = np.rad2deg(self.fov_orientations[i_entity]) + + for zone in zones[i_entity]: zone.set_center(( this_pose[i_entity, 0], this_pose[i_entity, 1], )) if fovs[i_entity] < 360: - ori_deg = np.rad2deg(this_pose[i_entity, 2]) - zone.theta1 = ori_deg - fovs[i_entity] / 2 + zone.theta1 = ori_deg - fovs[i_entity] / 2 zone.theta2 = ori_deg + fovs[i_entity] / 2 if render_swarm_center: -- GitLab From 26710c205d04ed26de91f6838e10e406f9d5d04d Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Mon, 24 Mar 2025 14:56:49 +0100 Subject: [PATCH 7/8] Moved social vectors to robofish.io.utils to finally solve issue #24 --- .gitignore | 2 + src/robofish/evaluate/app.py | 11 +-- src/robofish/evaluate/evaluate.py | 91 ++++++++++---------- src/robofish/io/app.py | 17 ++-- src/robofish/io/file.py | 87 ++++++++++++------- src/robofish/io/utils.py | 70 ++++++++++++++- tests/robofish/evaluate/test_app_evaluate.py | 1 + tests/robofish/io/test_app_io.py | 4 +- 8 files changed, 191 insertions(+), 92 deletions(-) diff --git a/.gitignore b/.gitignore index 0cd5109..674857f 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,8 @@ env *.hdf5 *.mp4 +# Ignore all files with ignore_ prefix +ignore_* !tests/resources/*.hdf5 feature_requests.md diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index 5c837bd..78697fd 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -94,26 +94,27 @@ def evaluate(args: dict = None) -> None: default=None, ) parser.add_argument( - "--save_path", + "--save-path", + "--save_path", # for backwards compatibility type=str, help="Filename for saving resulting graphics.", default=None, ) parser.add_argument( - "--max_files", + "--max-files", + "--max_files", # for backwards compatibility type=int, default=None, help="The maximum number of files to be loaded from the given paths.", ) parser.add_argument( - "--add_train_data", + "--add-train-data", + "--add_train_data", # for backwards compatibility action="store_true", help="Add the training data to the evaluation.", default=False, ) - # TODO: ignore fish/ consider_names - if args is None: args = parser.parse_args() diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 1f4d2af..793eb31 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -31,6 +31,7 @@ def evaluate_speed( labels: Iterable[str] = None, predicate: Callable[[robofish.io.Entity], bool] = None, speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the speed of the entities as histogram. @@ -47,7 +48,7 @@ def evaluate_speed( if speeds_turns_from_paths is None: speeds_turns_from_paths, _ = utils.get_all_data_from_paths( - paths, "speeds_turns", predicate=predicate + paths, "speeds_turns", predicate=predicate, max_files=max_files ) speeds = [] @@ -92,6 +93,7 @@ def evaluate_turn( predicate: Callable[[robofish.io.Entity], bool] = None, speeds_turns_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the turn angles of the entities as histogram. @@ -109,7 +111,7 @@ def evaluate_turn( if speeds_turns_from_paths is None: speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( - paths, "speeds_turns", predicate=predicate + paths, "speeds_turns", predicate=predicate, max_files=max_files ) turns = [] @@ -156,6 +158,7 @@ def evaluate_orientation( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the orientations of the entities on a 2d grid. @@ -172,7 +175,7 @@ def evaluate_orientation( """ if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate=predicate + paths, predicate=predicate, max_files=max_files ) world_bounds = [ @@ -243,6 +246,7 @@ def evaluate_relative_orientation( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the relative orientations of the entities as a histogram. @@ -260,7 +264,7 @@ def evaluate_relative_orientation( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) orientations = [] @@ -299,6 +303,7 @@ def evaluate_distance_to_wall( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of the entities to the walls as a histogram. Lambda function example: lambda e: e.category == "fish" @@ -315,7 +320,7 @@ def evaluate_distance_to_wall( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) world_bounds = [ @@ -387,6 +392,7 @@ def evaluate_tank_position( poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, max_points: int = 4000, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the positions of the entities as a heatmap. Lambda function example: lambda e: e.category == "fish" @@ -404,7 +410,7 @@ def evaluate_tank_position( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) xy_positions = [] @@ -479,18 +485,13 @@ def evaluate_quiver( """ try: import torch - from fish_models.models.pascals_lstms.attribution import SocialVectors except ImportError: - print( - "Either torch or fish_models is not installed.", - "The social vector should come from robofish io.", - "This is a known issue (#24)", - ) + print("Torch is not installed and could not be imported") return if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate, predicate + paths, predicate, max_files=max_files ) if speeds_turns_from_paths is None: speeds_turns_from_paths, file_settings = utils.get_all_data_from_paths( @@ -547,10 +548,11 @@ def evaluate_quiver( # print(tank_directions[x,y] - tank_directions_speed[x,y]) tank_count[x, y] = len(d) - sv = SocialVectors(poses_from_paths[0]) - sv_r = torch.tensor(sv.social_vectors_without_focal_zeros)[:, :, :-1].reshape( - (-1, 3) - ) # [:1000] + sv_r = torch.tensor( + robofish.io.utils.social_vectors_without_focal_zeros( + poses_from_paths[0][:, :, :-1] + ).reshape((-1, 4)) + ) sv_r = torch.cat( (sv_r[:, :2], torch.cos(sv_r[:, 2:]), torch.sin(sv_r[:, 2:])), dim=1 ) @@ -631,6 +633,7 @@ def evaluate_social_vector( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the vectors pointing from the focal fish to the conspecifics as heatmap. Lambda function example: lambda e: e.category == "fish" @@ -645,17 +648,9 @@ def evaluate_social_vector( matplotlib.figure.Figure: The figure of the social vectors. """ - try: - from fish_models.models.pascals_lstms.attribution import SocialVectors - except ImportError as e: - warnings.warn( - "Please install the fish_models package to use this function.\n", e - ) - return - if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) fig, ax = plt.subplots( @@ -675,8 +670,8 @@ def evaluate_social_vector( axis=-1, ) - social_vec = SocialVectors(poses).social_vectors_without_focal_zeros - flat_sv = social_vec.reshape((-1, 3)) + social_vec = robofish.io.utils.social_vectors_without_focal_zeros(poses) + flat_sv = social_vec.reshape((-1, 4)) bins = ( 30 if flat_sv.shape[0] < 40000 else 50 if flat_sv.shape[0] < 65000 else 100 @@ -700,6 +695,7 @@ def evaluate_follow_iid( predicate: Callable[[robofish.io.Entity], bool] = None, poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the follow metric in respect to the inter individual distance (iid). Lambda function example: lambda e: e.category == "fish" @@ -716,7 +712,7 @@ def evaluate_follow_iid( if poses_from_paths is None: poses_from_paths, file_settings = utils.get_all_poses_from_paths( - paths, predicate + paths, predicate, max_files=max_files ) follow, iid = [], [] @@ -792,12 +788,11 @@ def evaluate_follow_iid( # Created an issue fig = plt.figure(figsize=(12, 5)) - gs = gridspec.GridSpec(1, len(grids)) - - for i in range(len(grids)): - SeabornFig2Grid(grids[i], fig, gs[i]) - + + for i in range(len(grids)): + SeabornFig2Grid(grids[i], fig, gs[i]) + return fig @@ -808,6 +803,7 @@ def evaluate_tracks_distance( poses_from_paths: Iterable[Iterable[np.ndarray]] = None, file_settings: dict = None, max_timesteps: int = 4000, + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -828,6 +824,7 @@ def evaluate_tracks_distance( predicate, lw_distances=True, max_timesteps=max_timesteps, + max_files=max_files, ) @@ -841,7 +838,7 @@ def evaluate_tracks( seed: int = 42, max_timesteps: int = None, verbose: bool = False, - max_files: int = None + max_files: int = None, ) -> matplotlib.figure.Figure: """Evaluate the distances of two or more fish on the track. Lambda function example: lambda e: e.category == "fish" @@ -867,7 +864,7 @@ def evaluate_tracks( max_files_per_path = max([len(files) for files in files_per_path]) min_files_per_path = min([len(files) for files in files_per_path]) - + rows, cols = len(files_per_path), min(6, max_files_per_path) multirow = False @@ -1317,26 +1314,26 @@ def calculate_distLinePoint( def show_values( pc: matplotlib.collections.PolyCollection, fmt: str = "%.2f", **kw: dict ) -> None: - """Show numbers on plt.ax.pccolormesh plot. - - https://stackoverflow.com/questions/25071968/ - heatmap-with-text-in-each-cell-with-matplotlibs-pyplot + """Show numbers on plt.ax.pcolormesh plot. Args: pc(matplotlib.collections.PolyCollection): The plot to show the values on. fmt(str): The format of the values. kw(dict): The keyword arguments. """ + values = pc.get_array() pc.update_scalarmappable() ax = pc.axes - for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), pc.get_array()): + + for p, color, value in zip(pc.get_paths(), pc.get_facecolors(), values.ravel()): x, y = p.vertices[:-2, :].mean(0) - if np.all(color[:3] > 0.5): - color = (0.0, 0.0, 0.0) - else: - color = (1.0, 1.0, 1.0) - if np.all(value == "--"): - ax.text(x, y, fmt % value, ha="center", va="center", color=color, **kw) + + # Choose text color based on background + text_color = (0.0, 0.0, 0.0) if np.all(color[:3] > 0.5) else (1.0, 1.0, 1.0) + + # Only show non-masked values + if not np.ma.is_masked(value): + ax.text(x, y, fmt % value, ha="center", va="center", color=text_color, **kw) class SeabornFig2Grid: diff --git a/src/robofish/io/app.py b/src/robofish/io/app.py index 4aec636..f0f03a3 100644 --- a/src/robofish/io/app.py +++ b/src/robofish/io/app.py @@ -41,14 +41,16 @@ def print_file(args: argparse.Namespace = None) -> bool: parser.add_argument("path", type=str, help="The path to a hdf5 file") parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, choices=["shape", "full"], default="shape", help="Choose how datasets are printed, either the shapes or the full content is printed", ) parser.add_argument( - "--full_attrs", + "--full-attrs", + "--full_attrs", # backwards compatibility default=False, action="store_true", help="Show full unabbreviated values for attributes", @@ -122,7 +124,8 @@ def validate(args: argparse.Namespace = None) -> int: description="The function can be directly accessed from the commandline and can be given any number of files or folders. The function returns the validity of the files in a human readable format or as a raw output." ) parser.add_argument( - "--output_format", + "--output-format", + "--output_format", # backwards compatibility type=str, default="h", choices=["h", "raw"], @@ -179,7 +182,9 @@ def render_file(kwargs: Dict) -> None: kwargs (Dict, optional): A dictionary containing the arguments for the render function. """ with robofish.io.File(path=kwargs["path"]) as f: - f.render(**{k:v for k,v in kwargs.items() if k not in ("path", "reference_track")}) + f.render( + **{k: v for k, v in kwargs.items() if k not in ("path", "reference_track")} + ) def overwrite_user_configs() -> None: @@ -248,7 +253,7 @@ def render(args: argparse.Namespace = None) -> None: default=[], help="Custom colors to use for guppies. Use spaces as delimiter. " "To set all guppies to the same color, pass only one color. " - "Hexadecimal values, color names and matplotlib abbreviations are supported (\"#000000\", black, k)" + 'Hexadecimal values, color names and matplotlib abbreviations are supported ("#000000", black, k)', ) default_options = { @@ -267,7 +272,7 @@ def render(args: argparse.Namespace = None) -> None: "render_swarm_center": False, "highlight_switches": False, "figsize": 10, - "fov_smoothing_factor": 0.8 + "fov_smoothing_factor": 0.8, } for key, value in default_options.items(): diff --git a/src/robofish/io/file.py b/src/robofish/io/file.py index 8af3217..0afc5cd 100644 --- a/src/robofish/io/file.py +++ b/src/robofish/io/file.py @@ -936,7 +936,7 @@ class File(h5py.File): else: step_size = poses.shape[1] - cmap = matplotlib.cm.get_cmap(cmap) + cmap = plt.get_cmap(cmap) x_world, y_world = self.world_size if figsize is None: @@ -1026,7 +1026,7 @@ class File(h5py.File): custom_colors: bool = None, dpi: int = 200, figsize: int = 10, - fov_smoothing_factor: float = 0.8 + fov_smoothing_factor: float = 0.8, ) -> None: """Render a video of the file. @@ -1102,26 +1102,50 @@ class File(h5py.File): if render_zones: fovs = [] for ei, e in enumerate(self.entities): - zone_sizes_str = get_zone_sizes_from_model_str(self.attrs.get("guppy_model_rollout", "")) + zone_sizes_str = get_zone_sizes_from_model_str( + self.attrs.get("guppy_model_rollout", "") + ) zone_sizes_attrs = get_zone_sizes_from_attrs(e) - + # Check that there are no zone sizes in the model and in the attributes - assert zone_sizes_str == {} or zone_sizes_attrs == {}, "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." - zone_sizes = zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str - - fov = zone_sizes.get("fov", np.pi*2) + assert ( + zone_sizes_str == {} or zone_sizes_attrs == {} + ), "There are zone sizes in the model and in the attributes. Please use only one (preferrably the attributes)." + zone_sizes = ( + zone_sizes_attrs if zone_sizes_attrs != {} else zone_sizes_str + ) + + fov = zone_sizes.get("fov", np.pi * 2) fov = np.rad2deg(fov) zone_sizes.pop("fov", None) fovs.append(fov) - + entity_zones = [] for zone_size in zone_sizes.values(): if fov >= 360: - entity_zones.append(matplotlib.patches.Circle((0, 0), zone_size, color=fish_colors[ei], alpha=0.2, fill=False)) + entity_zones.append( + matplotlib.patches.Circle( + (0, 0), + zone_size, + color=fish_colors[ei], + alpha=0.2, + fill=False, + ) + ) else: - entity_zones.append(matplotlib.patches.Wedge((0,0), zone_size, theta1=-fov/2, theta2=fov/2, color=fish_colors[ei], alpha=0.2, fill=False)) + entity_zones.append( + matplotlib.patches.Wedge( + (0, 0), + zone_size, + theta1=-fov / 2, + theta2=fov / 2, + color=fish_colors[ei], + alpha=0.2, + fill=False, + ) + ) zones.append(entity_zones) - + zones_flat = [] for zones_fish in zones: for zone in zones_fish: @@ -1321,7 +1345,7 @@ class File(h5py.File): pbar = tqdm(range(n_frames)) if video_path is not None else None self.fov_orientations = [None] * n_entities - + def update(frame): output_list = [] @@ -1399,19 +1423,27 @@ class File(h5py.File): poses_trails = entity_poses[:, max(0, file_frame - trail) : file_frame] for i_entity in range(n_entities): new_ori = this_pose[i_entity, 2] - old_ori = self.fov_orientations[i_entity] if self.fov_orientations[i_entity] is not None else new_ori + old_ori = ( + self.fov_orientations[i_entity] + if self.fov_orientations[i_entity] is not None + else new_ori + ) # Mix in complex space to handle the wrap around - smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + (1-fov_smoothing_factor) * np.exp(1j * new_ori) + smoothed_ori = fov_smoothing_factor * np.exp(1j * old_ori) + ( + 1 - fov_smoothing_factor + ) * np.exp(1j * new_ori) self.fov_orientations[i_entity] = np.angle(smoothed_ori) - + ori_deg = np.rad2deg(self.fov_orientations[i_entity]) - - for zone in zones[i_entity]: - zone.set_center(( - this_pose[i_entity, 0], - this_pose[i_entity, 1], - )) + + for zone in zones[i_entity]: + zone.set_center( + ( + this_pose[i_entity, 0], + this_pose[i_entity, 1], + ) + ) if fovs[i_entity] < 360: zone.theta1 = ori_deg - fovs[i_entity] / 2 zone.theta2 = ori_deg + fovs[i_entity] / 2 @@ -1482,18 +1514,11 @@ def get_zone_sizes_from_attrs(e: robofish.io.Entity) -> Dict[str, float]: possible_params = ["fov", "zor", "zoa", "zoo", "preferred_d", "neighbor_radius"] model_params = e["model"]["parameters"].attrs return {p: model_params[p] for p in possible_params if p in model_params} - + def get_zone_sizes_from_model_str(model: str) -> Dict[str, float]: zone_sizes = {} - for zone in [ - "zor", - "zoa", - "zoo", - "preferred_d", - "neighbor_radius", - "fov" - ]: + for zone in ["zor", "zoa", "zoo", "preferred_d", "neighbor_radius", "fov"]: match = re.search(r"{}=(\d+(?:\.\d+)?)".format(zone), model) if match: value = float(match.group(1)) diff --git a/src/robofish/io/utils.py b/src/robofish/io/utils.py index 940879b..ebc18f1 100644 --- a/src/robofish/io/utils.py +++ b/src/robofish/io/utils.py @@ -48,7 +48,7 @@ def get_all_files_from_paths(paths: Iterable[Union[str, Path]], max_files=None): files_path = [] for ext in ("hdf", "hdf5", "h5", "he5"): files_path += list(path.rglob(f"*.{ext}")) - + files.append(sorted(files_path)[:max_files]) else: files.append([path]) @@ -123,3 +123,71 @@ def get_all_data_from_paths( all_data.append(data_from_files) pbar.close() return all_data, expected_settings + + +def social_vectors(poses): + """ + Args: + poses (np.ndarray): n_tracks, n_fish, n_timesteps, (x,y, ori_rad) + """ + + assert not np.any(np.isnan(poses)), "poses contains NaNs" + + n_tracks, n_fish, n_timesteps, three = poses.shape + assert three == 3, "poses.shape[-1] must be 3 (x,y,ori_rad). Got {three}" + + sv = np.zeros((n_tracks, n_fish, n_timesteps, n_fish, 4)) + for focal_id in range(n_fish): + for other_id in range(n_fish): + if focal_id == other_id: + continue + + # Compute relative orienation + # sv[:, focal_id, :, other_id, 2] = ( + # poses[:, other_id, :, 2] - poses[:, focal_id, :, 2] + # ) + relative_orientation = poses[:, other_id, :, 2] - poses[:, focal_id, :, 2] + sv[:, focal_id, :, other_id, 2:4] = np.stack( + [np.cos(relative_orientation), np.sin(relative_orientation)], axis=-1 + ) + + # Raw social vectors + xy_offset = poses[:, other_id, :, :2] - poses[:, focal_id, :, :2] + + # Rotate them, so they are from the POV of the focal fish + + # to convert a vector from world coords to fish coords (when the fish has + # an orientation of alpha), you need to rotate the vector by pi/2 first + # and then rotate it back by alpha + rotate_by = np.pi / 2 - poses[:, focal_id, :, 2] + + # perform rotation of raw social vectors + sin, cos = np.sin(rotate_by), np.cos(rotate_by) + sv[:, focal_id, :, other_id, 0] = ( + xy_offset[:, :, 0] * cos - xy_offset[:, :, 1] * sin + ) + sv[:, focal_id, :, other_id, 1] = ( + xy_offset[:, :, 0] * sin + xy_offset[:, :, 1] * cos + ) + + assert not np.isnan(sv).any(), np.isnan(sv) + + # assert length of social vectors was preserved + assert True or np.all( + np.isclose( + np.linalg.norm( + social_vectors[:, focal_id, :, other_id, :2], axis=2 + ), + np.linalg.norm(xy_offset, axis=2), + ) + ), "The length of the social vector was not preserved." + + return sv + + +def social_vectors_without_focal_zeros(poses): + sv = social_vectors(poses) + mask = np.full_like(sv, fill_value=True, dtype=bool) + for focal_id in range(sv.shape[1]): + mask[:, focal_id, :, focal_id] = False + return sv[mask].reshape((sv.shape[0], sv.shape[1], sv.shape[2], sv.shape[3] - 1, 4)) diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 4b1332f..1d0ff26 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,6 +22,7 @@ class DummyArgs: self.save_path = save_path self.labels = None self.add_train_data = add_train_data + self.max_files = None def test_app_validate(tmp_path): diff --git a/tests/robofish/io/test_app_io.py b/tests/robofish/io/test_app_io.py index 6c3a16f..c9c7515 100644 --- a/tests/robofish/io/test_app_io.py +++ b/tests/robofish/io/test_app_io.py @@ -22,8 +22,8 @@ def test_app_validate(): with pytest.warns(UserWarning): raw_output = app.validate(DummyArgs([resources_path], "raw")) - # The three files valid.hdf5, almost_valid.hdf5, and invalid.hdf5 should be found. - assert len(raw_output) == 4 + # The three files valid.hdf5, almost_valid.hdf5, valid_couzin_params.hdf5, and invalid.hdf5 should be found. + assert len(raw_output) == 5 # invalid.hdf5 should pass a warning with pytest.warns(UserWarning): -- GitLab From aecea162beac0b20371be85cbbc0f17d0880537c Mon Sep 17 00:00:00 2001 From: Andi <andi.gerken@gmail.com> Date: Mon, 24 Mar 2025 15:03:56 +0100 Subject: [PATCH 8/8] Deactivated windows pipeline. It fails but not because of the tests but because of the infrastructure. No time to fix it currently. --- .gitlab-ci.yml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 1ffc452..8649c2f 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -11,11 +11,11 @@ stages: .macos: tags: [macos, shell] -.windows: - tags: [windows, docker] - image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/windows:latest-devel - before_script: - - . $Profile.AllUsersAllHosts +#.windows: +# tags: [windows, docker] +# image: git.imp.fu-berlin.de:5000/bioroboticslab/auto/ci/windows:latest-devel +# before_script: +# - . $Profile.AllUsersAllHosts .python38: &python38 PYTHON_VERSION: "3.8" @@ -54,9 +54,9 @@ package: extends: .macos <<: *test -"test: [windows, 3.8]": - extends: .windows - <<: *test +#"test: [windows, 3.8]": +# extends: .windows +# <<: *test deploy to staging: extends: .centos -- GitLab