From 9602ba19b2361816b4470922b5d5476857c0ff97 Mon Sep 17 00:00:00 2001 From: Andi Gerken <andi.gerken@gmail.com> Date: Tue, 9 Mar 2021 16:07:08 +0100 Subject: [PATCH] Modes for evaluation, warning in vector length - added all modes of evaluation with tests - changed the assertion error of orientation vector lengths other than 1 to warnings. --- .gitignore | 1 + examples/example_readme.py | 5 ++-- src/robofish/evaluate/app.py | 29 ++++++++++++------- src/robofish/evaluate/evaluate.py | 8 ++--- src/robofish/io/validation.py | 13 +++++++-- tests/resources/valid.hdf5 | Bin 12296 -> 65536 bytes tests/robofish/evaluate/test_app_evaluate.py | 5 ++-- tests/robofish/io/test_file.py | 4 +-- 8 files changed, 41 insertions(+), 24 deletions(-) diff --git a/.gitignore b/.gitignore index 043b66a..5f76274 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ env *.mp4 feature_requests.md +output_graph.png diff --git a/examples/example_readme.py b/examples/example_readme.py index 62df233..197ada6 100644 --- a/examples/example_readme.py +++ b/examples/example_readme.py @@ -10,11 +10,12 @@ def create_example_file(path): # Create a new robot entity. Positions and orientations are passed # separately in this example. Since the orientations have two columns, # unit vectors are assumed (orientation_x, orientation_y) + circle_rad = np.linspace(0, 2 * np.pi, num=100) f.create_entity( category="robot", name="robot", - positions=np.zeros((100, 2)), - orientations=np.ones((100, 2)) * [0, 1], + positions=np.stack((np.cos(circle_rad), np.sin(circle_rad))).T * 40, + orientations=np.stack((-np.sin(circle_rad), np.cos(circle_rad))).T, ) # Create a new fish entity. diff --git a/src/robofish/evaluate/app.py b/src/robofish/evaluate/app.py index ab7b17e..4d603f3 100644 --- a/src/robofish/evaluate/app.py +++ b/src/robofish/evaluate/app.py @@ -13,6 +13,21 @@ import robofish.evaluate import argparse +def function_dict(): + base = robofish.evaluate.evaluate + return { + "speed": base.evaluate_speed, + "turn": base.evaluate_turn, + "orientation": base.evaluate_orientation, + "relative_orientation": base.evaluate_relativeOrientation, + "distance_to_wall": base.evaluate_distanceToWall, + "tank_positions": base.evaluate_tankpositions, + "trajectories": base.evaluate_trajectories, + "evaluate_positionVec": base.evaluate_positionVec, + "follow_iid": base.evaluate_follow_iid, + } + + def evaluate(args=None): """This function can be called from the commandline to evaluate files. @@ -24,13 +39,7 @@ def evaluate(args=None): (robofish-io-evaluate --help for more info) """ - function_dict = { - "speed": robofish.evaluate.evaluate.evaluate_speed, - "turn": robofish.evaluate.evaluate.evaluate_turn, - "tank_positions": robofish.evaluate.evaluate.evaluate_tankpositions, - "trajectories": robofish.evaluate.evaluate.evaluate_trajectories, - "follow_iid": robofish.evaluate.evaluate.evaluate_follow_iid, - } + fdict = function_dict() parser = argparse.ArgumentParser( description="This function can be called from the commandline to evaluate files.\ @@ -42,7 +51,7 @@ def evaluate(args=None): parser.add_argument( "analysis_type", type=str, - choices=function_dict.keys(), + choices=fdict.keys(), help="The type of analysis.\ speed - A histogram of speeds\ turn - A histogram of angular velocities\ @@ -76,7 +85,7 @@ def evaluate(args=None): if args is None: args = parser.parse_args() - if args.analysis_type in function_dict: - function_dict[args.analysis_type](args.paths, args.names, args.save_path) + if args.analysis_type in fdict: + fdict[args.analysis_type](args.paths, args.names, args.save_path) else: print(f"Evaluation function not found {args.analysis_type}") diff --git a/src/robofish/evaluate/evaluate.py b/src/robofish/evaluate/evaluate.py index 4db49e8..db339f1 100644 --- a/src/robofish/evaluate/evaluate.py +++ b/src/robofish/evaluate/evaluate.py @@ -160,7 +160,7 @@ def evaluate_orientation(paths, names=None, save_path=None, predicate=None): if names is None: ax[i].set_title("Mean orientation in tank") else: - ax[i].set_title("Mean orientation in tank (" + names[i] + ")") + ax[i].set_title("Mean orientation in tank (%s)" % names[i]) ax[i].set_xlabel("x [cm]") ax[i].set_ylabel("y [cm]") @@ -317,7 +317,7 @@ def evaluate_tankpositions(paths, names=None, save_path=None, predicate=None): names = paths for i in range(len(x_pos)): - ax[i].set_title("Tankpositions (" + names[i] + ")") + ax[i].set_title("Tankpositions (%s)" % names[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) @@ -367,7 +367,7 @@ def evaluate_trajectories(paths, names=None, save_path=None, predicate=None): sns.scatterplot( x="x", y="y", hue="Agent", linewidth=0, s=4, data=pos[i][1], ax=ax[i] ) - ax[i].set_title("trajectories (" + names[i] + ")") + ax[i].set_title("Trajectories (%s)" % names[i]) ax[i].set_xlim(-world_bounds[i][0] / 2, world_bounds[i][0] / 2) ax[i].set_ylim(-world_bounds[i][1] / 2, world_bounds[i][1] / 2) ax[i].invert_yaxis() @@ -680,4 +680,4 @@ class SeabornFig2Grid: self.fig.canvas.draw() def _resize(self, evt=None): - self.sg.fig.set_size_inches(self.fig.get_size_inches()) \ No newline at end of file + self.sg.fig.set_size_inches(self.fig.get_size_inches()) diff --git a/src/robofish/io/validation.py b/src/robofish/io/validation.py index a1bde52..0e1ae54 100644 --- a/src/robofish/io/validation.py +++ b/src/robofish/io/validation.py @@ -6,7 +6,9 @@ import numpy as np import logging -def assert_validate(statement: bool, message: str, location: str = None) -> None: +def assert_validate( + statement: bool, message: str, location: str = None, strict_validate=True +) -> None: """ Assert the statement and attach the entity name to the error message. Args: @@ -18,9 +20,12 @@ def assert_validate(statement: bool, message: str, location: str = None) -> None """ if not statement: if location: - raise AssertionError("%s in %s" % (message, location)) - else: + message = "%s in %s" % (message, location) + + if strict_validate: raise AssertionError(message) + else: + logging.warning(message) def assert_validate_type( @@ -329,11 +334,13 @@ def validate_positions_range(world_size, positions, e_name): def validate_orientations_length(orientations, e_name): ori_lengths = np.linalg.norm(orientations, axis=1) + # Check if all orientation lengths are all 1. Different lengths cause warnings. assert_validate( np.isclose(ori_lengths, 1).all(), "The orientation vectors were not unit vectors. Their length was in the range [%.2f, %.2f] when it should be 1" % (min(ori_lengths), max(ori_lengths)), e_name, + strict_validate=False, ) diff --git a/tests/resources/valid.hdf5 b/tests/resources/valid.hdf5 index 0c14ecdad78385832491ff147824f4e259701a05..de8a1dabe0f28ecb0b1600aa01918328858e0f8f 100644 GIT binary patch delta 4088 zcmeB3XlP)WAR@-dfB^rYfWcv+mIWin#^kfif&vTzU<D2eV2UAQvLlN>r-L$t$FXrE zI}@YC<b`bFl12GR`6Ubt3~8Cg8S#b;3^0=&HeX~DXJqV{{E~fny*5N2Lr6wuu|j6C zLZU)(W^O@FszPc-BA71EEXh#FO-xBuC@oM(Ni0d!1M6S`lMD>3P#R_#j1Q9sSq@Ud z1|k?37^I-IZ>2&}eo3NoNoHQULUBQ2dPZtVNu~l!IaCz`1H%oz%^Nw+@okcD;85UU z2w(si%*Y_fAi}`GkXliYT9lcanpYBEoLW*^FgZX)Lm2EM4u}EFV2We1wWKbmgc_K` zz_D>7KNI7N$(2f}noQ8ZWMC*TfXKoEX$F+P!w4)s`M#1Q=L%?$pP0<3{C%<lGY680 zp^E!t0R<jLmdUy-vXg)5RWK!}PnJ<tVBXj3KG|Niktsoc;zqg24^%mrI8--o6kx1Z zZ~~d!xmMn(Ay&zuVw;@PMb1SIXQgGGZrXlw(7P|?B(_Jv@fDMVQ~zdn$IqQ2PP%t< z9q*S5IGye2bi99o+sSt7Y{&gG*_?bB*E-(6&fvtdd#_`R(pN{#jb|ObXFPXwyL`v7 zPX4CjLBZFK500I5d>#A4(RtoZM~+Y?C;$3wD;y7gVRtH0nC=+D!RvJ9U5jITm!Okv zMvCL6GEt{|c00#2CnTKqJrZ+Fua$P<biL=WYoe@E+5KLJ1<dkJljeQ4pM6o@X-&D( z{u{sKog5x|9nh(lciN@7?SNRkoKu@H^T7kNWt@b%wGRF=lXAKk6nN0aSlsEPQ~AL~ zlZBn?KS=f+d=bs>#K5rNVEP0uCnd*?2L;SnoqA3mJa~S|KS#s9iw7C6e{{TD@!;SK z{wIzCHSZ2?kG<+x(fRw}wDU(DPoHEy)LpvOk-?qwP)WiP$IWJZhh8t7<al<n&>?G| z2FLS}Vuz+aiE_N!B6&!_-^6jYpv<AQ|9Kqi|H>YEW_zmMp~Fr7(7W(rhs4+NhYE@o z*+c35Q1Js$g$JM-4?+z(2sQX1)S`n>iw{B_auDjUgHVSaggX2n)FTHW9;^R&5bD!| zP@f-yhQuLg$Q*(O)gfrm9fAhsA!yJZf(G>=XwV;mM#Uj$)Et6F)gfrq9fC&XA#l{z zpO82Njq5|uxIY9<3y1t6Y2y$yu^fUXo`cZDbP$@j4nh;#L1^MT2u+L!p^5V#G_f9p zCf<Y4#C#Bxxa-CCC>(^Q_5;w=e*l^r4lpn%IMvk}tYBbZXkcJyuupr^xR<Y{-o9iZ z<GuqQYVAc%JMLQ-Uu|D8tzloy>2mv<hvx6QV_t0E?X_=T>&jev5xzV7jMUQYkH7n| zZ^o8*`+{_?{V6fw_7_fw?>AKQv)|pTQop}e+QpuArNRD1ZkF~Yd^Y<}Z_>5z+~~Id zS+=4*Z%4rXIdvlTv)d#0JN#g?f4eMkf6LjAc1!MN?04UI#V$Z4fB&hS+w3Y!O7?$d zong21M8*D7S%r2nGBx|JTrjjV&Z*n4{Qr^d(XI9S9iKYc9{66rKkR4Yo(G2Y`_K9= z+WR4;zIOlbU4r}0Osd*{eMjKFn0;mY9rHW(O@3Xtzoc~ezD)U?{p{zD?z<VEx_@)b z<9*jR$L?ofVA$`Y9J+t&S>F9T+r9VOa!T(%SK+w-%M$hd3dLso?@Tt{pRiGD|MHV| z`&Zk`?LThix&NA$;Qop`LHlDHnD?Ju61{)*oLBpHbS3XUQtx(lpJ-dweizyG`!@9# z?02xAurG9B>3*@9iTjRht=unnT5;c*6E*wytvR&U{(9YhwXckO8SmGFXj>4yZ_OcF zkht7wMLUo>v6+c>Abk$@6YN0dxX7-z1DSio?W`Thp4D?+*?}yMZD6(sIi&uYmY_Yz zWvlJw>_IL~*r;U>a=AjWnLWrS=PDfSK|bTz?rjh9sgH7~J;>+RH^<t8LgHq8sy!%V zGUapZK_NByb)h{d<YM-f*@Hsz%%m!NP{@8rskH}%^aI0sdr;6H_+D=x_Or2m&(W>* z_Mo^i&Z)Bp#Z`<<jXfytmY%4v2gPM&Nr^ouZUa>E?Ll$9<ZgyNDDK}bOSA{2h1u<q z_Mo)E+Yw+7N-Lck-RwbW$Ar(u9+Z|?R~p!Z($?-?6?;%xyKq9>9+dbB(z)zYV!XmZ zsnZUWIz{;I*nv`Ox7R*9Q0l#TXucgNHCIe)umh!Tk<*TLpwwQnkkJm5`qQ2?+JbU} zkzubYD0gh!cz|*80+lMJ1}&s2r%;iRfgN1wOMogG1_lOL)4{+PBG53|QBs}rfC+>L zt{0hZ7)`EHiUU<{Tr=QR;AA6}AW=}w4l^ME6ax$l3=1ZI)N5o)Fh{bsk%_|$R?}@T z-nfwu>NG1PHH8xoc-4CxVqgGamqQE;Anb66fdPbV4lyu*u*D$;1`swm#J~W;28S3J zKv?Gx0|N+a9AaPqVU<G+3?Qs<h=BowWezbgfUv|N1_lroImEyK!UBgF7(kfk5Ca3x zA%=PejzbJ!%yNi<0fZS2F))DekAn;hApGGV0|N-ZILN>N!VeBIFo5ulgA5EHeBmGi z0|=iu$iM)?2M#hYfbfok3=AN=;UEJ82(LKEzyQJv4uT4J28J02LAaiQVZuQO?l{Q6 z0KyFi85lsg;2;A72qzq5U;yEOgA5EH>~N5Q0fY?>GBAL!!a)WG5EeMdzyQJw2N@VZ z_`v}N1`s}QfPn#o7aU+<0O5uM3=AObaDaiqq5c2^0|OK{>}OyAVMsy%B`FAIaA05n zVFw2W1`uv=U|;~@1r7`hAbh}qfdPacI504PFoPol0|*N^GBAL!f+GV12pc#uFo3Xw zBLf2n2RJe?)PslwM+OEEE^uUE0O1Bl1_luBaAaTr;R%im3?Mwik%0k(7dSF7fba@O z1_lt`;K;xL!aE!p7(n=dBLf2npKxSg0O1Rc3=ANA!;ygjgdaFEFo5t2M}~Ta7mf@J zAE5Y$BLf2nGk_XpP7DkzP|V@PzyQKLP7DknEa1ez0Ky_p3=AMF;l#iI!ZJ<_3?K|@ z)Pb;y69WSXYdA45fUu4e0|N*fI59AQu!&PW1A~bZ1A_$=+c+^WfUtuT0|N-VI59AQ za057}t*fiI51i3pYyG>{K02JiPPe?;{_z8QJHfBz_7hgr*%ek4+gF*-v%AWgXMd?` zkKO)7>Gu2n-?Up4nqZ&3;+x$A)d>4f+MM=#rTy(CD#h%h9bD_}E%cP_H_WxNueQ?% zV-Oo87VY2)Rs>QE(ge~BG6ZB8$WV~sAWJ}&fh+}C4sry@F(5~Q90zhF$gv<tgB+ip zZn+ecrUGX)>;>_o!|NILffPJ`V80KfVZw^KeINs>%;)U`8F;B`&pwa^`~KhD2eL4G z#kYMR2Yk}z+z)b)M5Wk%kOM9Bl=p)iTy3YnALM}zbFKD+JQVHVx*z13z0&^sK|X$< z8nGYbnMI)q`$3-BU%x1QKgc&%S@ZUTd{bCaydUHn!LQ}}LB7!~uig*xjrH%^{UG10 ztE=A+@(n1bq^Iv(3gz#G%7Z)sQV;S3NI%FEAoD<;0GSW+1js&+CqNF^R}Xa&$QK|N zf_wpTF~}DnAAm8)H&7pf)qs2h4FRwrAm2bk2y6++GtdwOI|AexXb6Km0Sa-DCqPk9 zzyH+9U{IWZ1Z+XR0r?W-8<4L-z6XT@$N*5NfI<fpN+1hBp|*cfI>-T_1O^UPkY_-_ z400eS*g?S$iUyDeK+yt<9#AxaJOhe0iApaqa6E#11BzEr{DR^c<Qq`DgW?~Q4nV#E zr3+B{0HqU<Z$RlLI-Fr2C?SCo6sZ0a=V5Sw_6<3}eZ%5{)a1<6;>rJ+6{EqOPH2}9 z)DcCfm0;ju;DGdyi%SyoQWA?&;&W2-(n~VplXDpaI2afp+CT*%q6f_dD<~N_HgDwc O=ihvQ!+|{F9|r)l_U*?2 delta 1860 zcmZo@VChJhAR@+q1pY$-hrvWG3r4n$$!D1-e^8N{?8V~7bVF(L8~cBZlP9nVOm<-B zn7o2bfGsVvI3wP0avXcoWCivUj0%%qvM-<ff<u4{W@ZBaW=4lF{!J1N9FsjH6__@t z!VTfnfGTR3{7@;Ck#TaqvLvU20z@!j@<iqD6Ax%iR$%5p5}c{xKDon~hmm>mTa~Fy zKh!4Q(N|y=<k6k%Z_>#0LvP|nxyc_)IG7fwY~0AlSpT)*gyW_VCZ`FjJr16^{r%vX z*gprgd?OCtU$Dqgc8cC1DKRF;soz#Rt`BJ3Zx<nUU?HP`)9oYrP8~gO9H+i)a>%=G z?6gT-?$G`lN{1%6t~nU&SmEfi-S&`1AEV>Xe|e72`oA9hB(V2jZ!(+H9!thU&-X|; z)tnbRs1YH3u)xQ%-r@dEStsLOo<oxlOCL%Kl6DMINOF9;i`B^^O~Y}cIG@wMw*QWs zgM^)&Zc92TaHu$%IkG!RFxxpfN(eb!vrIYoOks+nS(EsoUm^^LUacxTnB*^b;9sfw zp%sc}4r*%6I&kV1@1eI5U5*Y@jvXw%VQ}by&360hnNJT&o#H$6psAkO>HVFqg9e5I z4kiyR53Z<_ajJf>_28+U8csj|u{tXCwK(!Ii#pwl;yb7_@6W-+i5iYIb21zqJhwa6 z`*1kz$TxSg%Z#xvJ)>~wUBd^*zZEkMrp&o{kWI<@&<56t2d8~xaB30=JotXcUB|Fv zUmWA&wGY+)usbN3$KW9U)y2uQ-cS2b+Bt<oi!|&HH9wr~pj}(-X#7>_;QQD|j!*4P z9H(5laqyN{ujA3o97m&<x=uB-uQ-OxU40<*jN&2JZ7Uo%s8u^Mh#DM<Dd<1Q?Wg7V za31HOTbT<EKB;$hda%y;P-EwQM`cqv$0eoj4lp-qI^H`k;#7Qd?!lXP-a6{7us>9v zuOD*ot@1+027Qi0A*=Kq7Fq9i%u{^h=o-d;C~$JfL2pL!Lz?@t9CDt2c5M1ycraqp zuY+2CYxnzeNFRC=q3tB``?G_v?k&fk+1K}*h3|9JKc93!C#C7&sh39%c6PcPQeD5- zvF7EC1MlxXI_U7~)IpIX8>h?tmmDr#_;gU+*|Og0d4}AfvQnKx>vbd!Rm3VCvak$2 z*t)6zpm+|0)8Eamj`ggI9Jp#99e818;@EOc%}H*yxYIG8tByH_%7=ch);Sblr*nuw z{M|vXHOz+;6|x-P%-ZdE`0;Z`|BkYQ$IJ~5nxEoxN;!1bVJpWO2M>Ye2i5jGbevoN z;GkM$i{qht_4NntMG7CX@#S}_pCf%}O`iQBZv%a&04`^zdSB0j&U~haeA{;&e0MB( z|L4p6`%meu-`Chvu-|F5kG;n+$^8*~!t5^kRoUk~skg70{&tVwqx$_~lfT(KE}pXQ z5JTzyTN3g6w%)AVuToQJH~n~py;Y%$eYm>H{yQ=H_9~55`yJ|kW$%}2IJWP~Y?ghh zH)`#p4LtY1@{zN@;aY9)_rQ6d|LfZQ-QG<591aNVd!1Wr|4*RGp4TvX-)Za0{k{(} z?cDC{+}GThx1WQ*)V^DEirxIMdi#Z{lWk+>d)T+K$lGuD$Zx-MQ}F)D86~z>-Szfr zclO)ukI3E6{NHc?tOa8G>vM8T_PZF**vD$5ZSN3fXD@t@d;g`(u>EdL4)&M#YuMlQ z6tfGqs<rRQzqxmMZ=L<A{|Wmt_f_s!Joeh|oNe;{0)BCOy>)*3wI^KLw=F-zzPEwR ze$&Os{dphj_dnaSW{<(Yvi(N$+V-{fc<gtVk+;|M(Xsy@X1iZyN7eqHFa7H6QgSES z9WN=}pMKxj?y+n2extZXySk0#_MB(A_r=Ys-Osy5$9}?poBjD$z4yC6l-hs3;kB)^ zZ=Jn#Y>oZBU)uX5WmE0FGoRRTGJDx?vzN7(`WJ71`RR8%u}O0KZ~yez-!6M$-`=O0 z`+ZM7v`adiYR_ofVds0a%)V#d@_nUG>I>}|@4ei2ba(Rpl)l>iztwp6&B!U;KW)iG zJMNtZ`y<UQ_cJ6L?YEt4Za<yTdw=&gX?y3eto_Fo&e`q0!Lr}rYLtEX`j7j5txMdm zH@)6|^Tc#p{WE3z#ko87CB3h=|5CQYX5!aE`xVy9?OYEB*=PC-?)RLKwO>)<tX-t= zx_xi!_ZQgn>|D23K(OAv^i#zC>gVkH73Uq?7yLJ8KkK9T{bt@j>`ETI-J9!Dx1U?P zV*k;|>V2PkrT3o*^xpsKYPNlk&T%`A51#wje^%K~bf0VIHNV)tQ2(jj?AIyw!mNV( zi@pWf=ZJ*ZPfO&tPj|1_uQ8+2PF&f;{=x$J{T7ew_xoHp_-Rk;F}3|`ES&89-k0sS z;%?u!{a&v9H~YPI8I{rYdrvUium4@Sf6|s1yEd+D`_QMS>|Q;)xcAYrI(r4r^8LBR jEq0;TEcag#G_c>e@c`rG1pO*TiOIM16~T4RfBixLKw?b< diff --git a/tests/robofish/evaluate/test_app_evaluate.py b/tests/robofish/evaluate/test_app_evaluate.py index 2a7fe48..fce024a 100644 --- a/tests/robofish/evaluate/test_app_evaluate.py +++ b/tests/robofish/evaluate/test_app_evaluate.py @@ -22,7 +22,6 @@ def test_app_validate(): self.names = None self.save_path = graphics_out - # TODO: Get rid of deprecation - with pytest.warns(DeprecationWarning): - app.evaluate(DummyArgs("speed")) + for mode in app.function_dict().keys(): + app.evaluate(DummyArgs(mode)) graphics_out.unlink() diff --git a/tests/robofish/io/test_file.py b/tests/robofish/io/test_file.py index 31c54bc..013d3a5 100644 --- a/tests/robofish/io/test_file.py +++ b/tests/robofish/io/test_file.py @@ -187,8 +187,8 @@ def test_load_validate(): def test_get_entity_names(): sf = robofish.io.File(path=valid_file_path) names = sf.entity_names - assert len(names) == 1 - assert names[0] == "fish_1" + assert len(names) == 2 + assert names == ["fish_1", "robot"] def test_File_without_path_or_worldsize(): -- GitLab