From 224015ec7b1a9e022e3c5f6218d6ea9ad52f7066 Mon Sep 17 00:00:00 2001 From: Simon Johansson <52204459+SeemonJ@users.noreply.github.com> Date: Thu, 24 Oct 2019 10:11:17 +0200 Subject: [PATCH] added the latentgan model (#59) * Added LatentGAN * attempt to remove submodule in LatentGAN * attempt to remove submodule from LatentGAN * removed ddc_pub submodule * Added refactored LatentGAN * attempt to remove submodule * removed data directory from git during development * untracking pointer files from git-lfs * re-added data files from origin * merged with changes from main MOSES repo * added the latentgan model * refactored latentgan code to comply with travis ci * removed temporary files and too heavy pretrained models --- .gitignore | 1 + images/LatentGAN.png | Bin 0 -> 34333 bytes moses/latentgan/README.md | 26 ++++ moses/latentgan/__init__.py | 5 + moses/latentgan/config.py | 92 ++++++++++++ moses/latentgan/model.py | 226 ++++++++++++++++++++++++++++++ moses/latentgan/trainer.py | 271 ++++++++++++++++++++++++++++++++++++ moses/models_storage.py | 3 + scripts/table_config.csv | 4 +- 9 files changed, 627 insertions(+), 1 deletion(-) create mode 100644 images/LatentGAN.png create mode 100644 moses/latentgan/README.md create mode 100644 moses/latentgan/__init__.py create mode 100644 moses/latentgan/config.py create mode 100644 moses/latentgan/model.py create mode 100644 moses/latentgan/trainer.py diff --git a/.gitignore b/.gitignore index 99287f0..74317c2 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,4 @@ build/ dist/ moses.egg-info/ + diff --git a/images/LatentGAN.png b/images/LatentGAN.png new file mode 100644 index 0000000000000000000000000000000000000000..0d541da90bf067f681c6cbb1263275773c1c1df3 GIT binary patch literal 34333 zcmZ^L1yq%5*Da}pbm<1^MghT1cPrgp0-|(xDBU3;h=72EfV2XFbcYCnpeP-RDBbn1 z?K$87-Er@@XN+?^gnPg5`>dF2&bfB9hMFP)9s&;y4UOQAlAIPA8ipAf8oC<}27ITi zTNWMuKzG+tltHWbK=Tv6z_OKAl}1CWPQ*X8#D=eNU6c&n(a;DxQU9X%I+s|Zp@rAn zk(1W(G5ei;*U>c=GpTWaK@c=r_LZ%j#PR6oLF zk9yA#LmzXSgg9xYYR;7#-dF>BY%X0AI?N^%W<0DdavJKr|X?p2e71xmm|lOqA1$i+c7U)@!4CW z(1*XyF8NYHrn~`u>&IvT6B84Ud5eRiV|Gr?hnvLMCMG7_CJh!24i5M3ZEwzZ4h^X| zwXcro>FFsbD467((ntn<`}{ebD}cb;`yli((SqpG0G8$sJ(BcZcfxhi7d`H)AFfLT zG#J+2FILN5pp3#GV^w+ZrKYf^X8Yxx6uE^B*QChkdf{owspI3=WNO1a86rY#J zkmWpCk^LY%L$g5Pp}oC(_(-l~z@bKymm?-7<`C0mIx!z&e%|Nt@dQr_pFjVW%%q5i z2qMJ6kz&`;(&}iHkdQDmG!zjL5p?=yP-ER+V>385HnzY1#mdn!KE2ZH{nwVjllT4y zRZ3=N=lv;cuU@@s^4`c1c4wWz>7{(IxVHAr^VetfT#5PcL$P}2*`C2s1AYCc-U$b1 zCkMYKEA@4BboBLmSLFyK%%+B6W5%%zg`*}?Q&UrYcH1&}baZrU(mZ%?ch@>r4x=QwP=Qu9k}zA)nf5;AO{3d> zyAOuGHM*~Ev|dE$%<+qhi;qhQe1GSel9nd!yUlFk`Qsyl!E4)*oR4c=(PRu&OG``W zn3o*BHIT$Re}A{Pm?*iwF-w;dT5VMCBNPVjugkXv*QAA?pNf0qt}$_NWM*Xe9Q-jIQo3_z{$tmdFV`xw zc`Y$!aJp5q1m@oyw)wnnYC5?%KWJSU&Kk+M6&e^6LM)>cA9jewWic679_kwyc=j0c za*Uqf^98%-w;ztdUEQkn+gpQou)4HwH&IH_rClYKJDepb8+Ps5wel!=V%^rDvln;t zFe8ZQ2&wK}awRYd>ETUuIXc;BwR&^oMrlcj)cuUd5_xqE4MOFm&a*9>jGX8*B@o#- z2?kvWl4yhy=ZA|#VM;@l2Gtl0X%kGW9mD0c>r>SQ^0Aar+d)A=aJ(!Wq%wrBl8Nc6 zp6cuA-QS6lb8~ac%gdwqp|<^=y`qPNNnvwb$nop#p@7pvlH-k;_b1g>y`-e18ln0( zO&WMx_ox;Y7cUdhIy`*XpLR6vv)DsGb8EB0{CscH+<$-Fa_8c7he$8{S`LClkOa%o z!NCB6Hx!bvX7Umi7JXw!i&bwzI7TMN(;J2%9>yQ{Ya=2qVX$mqtj1FDh!YVnpp!IHRgJ%Sj~8jy5|GKY^J)tY>M*^dql1^o8A2g(0<#(0IwF(N zuCA)ef`GVdt0?M;Qjk)`-Me>tdU|;6LoX2o1O|?x*_kx^+_EpPs7SanpuJe8U-{1C z=k}lPqD2a3&zN!M!yveE8yg!nFfQIdA|WA}YIVofv{>p(7WG=U*nw>1(kp-DnIW%N zVbZV7gjw_ObNRLNiBF$?O_Wh>1k0I#ZWDGE zmI)|mwYDQQ@>PU8>4k;M5d_q6adG4n6qgAQVFizg=p;_h4rcl94?U`08On5rl#=WJ z^Svo6GqZX!L?ZCGiuN{C!i|r4vPw!gD|tOeZ|oQhsBZUf{0ybI2M;I0vl9}&W3|y` z9>c!MeN`>s<@mS`8(Z=5-mlKj5E)&osYyu|rJf?SM)_B#N`P`!H zEv1oxU_1o8+7||BR~X|Bn@`l$EQ#`2eubm2m(+1uOi-}vmz%*50=Gc)t?<0iDzL^{d)i8t7ANZ|UT za2HAy`ijh3gFY3j>q9+n&9KE3)XV7V>Z0J%yK8%B(xKKiHb1W|g#D1QWA4JL)>(w}g_E;=_%Jg&JN}-YUV3=A z{K22^Tw-$kHP^?nso?BtY=b*fp+UX(+bi94pMtAB+5L&hv7M8#_Ul(!NeR@Qx}_y+ zJz3&>Y}DiGQjdoM6zb;mt)V@w)xzFh)MwKICz|VNBy>~;o=e!+uFE_Q-{0xpxf6}s z(EQ-R1NujlzK-|rv%0-(X^{X}k&HSL+#|c@C2hIT5)xA+v^xbevmCBXbX?VJ5enr- zw@hgjlN#$g(VaN{etwL1w`m8AyzW_9^(N8*0tl=aij0o#hqTnz)^6mJjv|DOEyklswayI+EYC_C3sP)J8 z$zfegm}mZ7R2azUsm+A!Yy@;lg!jCx{V6L~TAC?%!J#+WF*VgtH!c)->|xDUVRC;; zgPNsCfryCc;X`P}&hQ;5*(8?A%JBxfKR>a-+dL1vwl(4zJOghdE??aL%*Rh&ql!`l zI#@=VTFG{_S!_*BxGyjHr4v?-zWROsS${Qw874Be5$Gsq?JgSdXv28 zQuQ%0F(&i)KY#vAVN;)J^5)T})cepVhY)+!CunF`-mtcM$BtYA_2TsD;yMHb1pd1} zT=&*=y}G%Dg^wU;o;-PSUDPXVh?I=Xu$FIozBBS@XSLqxJHqlqcQ*?g+l7w1fx*=I z*@>vf4_=Elw3(7_o7K?*#RA1dr_ucGXGnV=Z+odV$c1o()xCQL@gLo0ShOxQ&)YuC z#@eK}q$0wJ0h+Q#6(Z0*SZ?0D$;4FaGB3lfdiUs75cQDby?YBk ze@5%hyMM?MB+}}NG^w(9tQ4njatpa)6SyBp=RLO89ic5aNmKn#J9DzMU%0~6vBPjAfY{d`0$}5tXo`6 z%%E^n8y)G_8f$ppVl-&FX!Z)F((}|bp_@#W%m$l(2A+%bIh3~|7;QmcQ zgLt&+E3}+J@x+;=jgEV36NTdI!`F+BA3q*ctl{@1L-LzQle1SMDyRKDXI|WNy=%76V;!BP&ElkX!l#;C z^^DO~>EZZ&+4}lqlV5yNcI0Y*gYw-x`REHJ@6P$l5x-_zgQLZj0s{kWVa|ZeQB0(D z)isxTvHqoY6%g#*08WA{k#T($LX`8+*R-^>1_lP?t>gB}4Qq?^E6ob`62tWboo9w9 zLl?I$Annah zj?n$C9HkX0Bt3oVeYiOf@p@g@4U6`=fWs7|8#d9WYRk?`TRK)VtHW&06Cy>ax%jaS zRfpDsg&!*9pZ=b4tE5nJMKadF5N0AP3u#76KrLYM#%|&Xi!$fwxJmv!nCfml&em9# z^pE-S%ta-%_~H$Ey#Y6}K0IIgMc^g|ErOXKWwVVrAw zdwb9aYyA(L!br!T3%Q1Pu90kZxwz^s=IjvJUt_=fEtSKreXNL@uk{yq%gK16f8ftl z&3HI!IFYr{g2?c2?Ay`i#Zdp6&AI>uN|h@n$Z4&AZ#pQ|EyKXXtR|>?_-O@zKskXX z_uo3aZ$~5K@&Zt#{tA<*sQ+wBAT;`j&W2Kqg@uLj@o^YcbOQ{ydX@*DD*wtuT)QR= ziC$saL>V3J~Kv!f$FKcAeO+}hgOfI8nK2@!pI z9zR79^5e&ksVM^ib9pMxO{x~R<$+RzMxU~+-_!N*f}9dg66E3EXZwFt@Iazv6$1 z#=-FlCK^WR8UeXRz=B8O3rWJ8k(1?y{4|+w9lq2+Sae1bt(s0$%EiUUZx#RgWxU4S z`ThI%{{CpL)XSd|4D-;kYRy@XUuN*xpJ78&(8EeddK-?9 zkH!6VH+OboqoQz#h|~e%I}dy%UIqZQG>ds~22DG8-W4Q=kDo=r}k|Mbz`VFK>GRkl$4b%D%s4nJw{O@Fi0iPLg;&~PoeYVK$P3F zC544yJ-8ygIFQCAEG$g?8!9w0L9$$)%6dpLHYNrJ%qO9Bn)lFdCII_l2F)1w2>tb7 zfU^?5+t$n!<SGb!9Vk(#^d7ywD{yy8h1;&3w3zx63%6=!RWC1tL^dRh zj!OKY)7H*#wkTEWeRzIBQMtkA!YnMFjr&6~8R zPZa}bkapuh#TfGB@P(+_0POae&$wk(P@XnonsN$(U47Fp?>L0f8 zZp@}A`$8RQ%*?gj;7{XCI~zDVcc{raPIP7+@((d8I`M%I$N#E_IM9K3u zIs{h;Ou+(9->#ZVybZgP!iof*g^%w7-P6y{58!({zKxB|7MFp6LDl}>+_-pbN`0Q5 zu~{`Kt0m1Xov=P#&+~-E>($Qci0#eftIW*Io2xOtCHSsb!7(50CM!(y1V2YyE{}d0 zrF07p0dBbtfZ*j;A1WG>#hbwRRC#nD zT;229;tYyyveNtYl4@gqL0<)Uxj9<#ACN$$nX6i0BHEa3ZH2s}pcbc#WSxqnyf#di zNi@qFWL{ySMr`N5x7NKsDCUwBreq|3vOHH(XE<$Um({$3qqI;upmJ>jhsF#GIhCJG zM?-=a+mICePFQteY`n$TD40rP#ayM#6|G=`t#=19M{hy9gew*Iy-d}Q~ocG3)%;1)u(qC;v6jk-13w? zqOg^$n$NB$!ZZj7a^{_9EDT$49{`P?ac$1e%{@9OZR6EC^ufud;xStq2msKf2-DYD zfEZtmTW(AI>u=7}zuC2^9dEk-$bo^WamS&eZiVP8SakF^JLho5ns=^tTdCZyFC+I991EW-XZB_KnI% z)IaanU;Id+zVyZ%@db-qv!JErT=L?ylLD`xwAAQv!V4%sE@`h`ap*EKc2k~o+#PcS*=qS&Px1p zQKPEb;db|&Y>{3j3qx_i+&J^}*$$toV7aj45nts{F{jdBTz#0hU?&FCg)!Fsp^WN) zBR2t}dg*vtTG~GD){Aq11D_b=BPuGmH|$uXb{~%BV{ZsNf)en?;S1?}qI@ocY{VQ( zhBQl)TI0h{tY&s{3(tX|v=6;jhkNCUb!-e%44V$mH-em+`R*^qbW96(e^gpNcu-PQ z)I)*C59KH-g)8-%+yyp1W$IzGU_do=E1r#RWwf37RW4UK8+-d-u8Te0y}jV8KYcchyjfh&#kNam;bV=AmyVsCT7AYC1KdDMh=1o7jqi$24 zGe3M1lH+bnqgCJn48eIp1ESjdek|H)60GDYJbd^Q@PDH-P}Unf5;IZ zY$C`^2ExA5ycJcg^qn3)`~oaDd$4is0?(X@W)9pc?DIWlKo70xYoLgJsGVl91@wPtJvutdYk^>tjkwaD z4~J8`FJ{&nL=Y&G^x!R!^FZ-gnG@N-3_?@c!3Y%Qy4jbYBX~kd=xbY)KD|IQZGKhB z5j*0VM0Ed24-UYD0qyV}pkN^J-vGl41?@w&s5cxPV2$rEL4DasXv^z?-U^g5jbc0^ zobCV;L_k2mezM~CXhAQajxG1ywF!LWH~fJ%C<)c2!tN`8aM4;fpex?Iaf92eISO%V zXC6q|$ZN4ehWXr1WjOcF_F-qlaTzk zZb^H6tFgg9iN=1EOWVs0utDyMjU?CD#s2SUs0jQpU&sZctIJSvRy4`;2<&r?RTh5* z6iE`(`(3gt)uk}=umv@jNt;ac`JC)K0;-P{(_|14@q?((7V(t%wptqVI5Zlb>Ko7x zd^V^7`&lbiY+?gf^wKL#nTLm`qV%eei!^uFD#uKLCq$ZYz0)#YKH95fc9OpgN`Dv|?*{EJ*86>QboA)Qk8RUwKJ^PA zwY+@!a%p*4xtX`obx|Qqt%4r!Bd%_~+~09z#?FI8D0K^B>EA*5Jc}Iw?IVkN7(~W$ z&Mxf?vj;Z!fF8H|=ez%WM|gQJpVD9M#TYsh{UW2__73r+=6M<6S_ISczq6@l=llWv zefjs)go%}eHqW0t!NA}DG1~dz14wb2>gtv75JDoWk%EF$FBH(oprHf-;`W*S_)$Y% zUcT!SObza_$|2(hX8+V{Vlug{mYy0WA>=k_l!$yYKq@xY)-7MfSI1bS3r_j@`1m|L zJ%OP5Sg1@@a!_`#F^Pp_T@CcwPQ2B)^rA92@W3fnHnN<1#$_)`<>oGX0iW!@>;`>B zsMak5yVFucrAZ@3B&alwLm9U)E_GYKs4g!rw?N*vwRQel_X>UhG#O}cJVC;>I18_P znn!Xs8pGF4-PeDh9c8Kd`Hdg*Yb>xm&80ST6n44z=;qg1)XlerCx>=_`|mel%ML8iRDqqu#tzSu zjl>K3z&Xxt^xb8kqDS2~mncJoW}Rd4qLVI{g*isOGWcPg{^AQpf0{#cudDWYZh6Bo4Q4u9Md6zDQ!seSE=5o$a09T}b5pfb-^^p#rkuOUo4ej2Efw4ApPyJHof8wv5NGUcCwhQ}%p1HhSPM z<-71&L0^9oDhXw94(C*Ba4=L8s4&U_5}nV5-QRirwsUqK(cN-g9-t5wZd|~WQGE%{vPMw=%VD4CqPK4@Mb*wE#{Z*t1`sUr#r< zGU*dz59IeGy;RBe6^SN&6Jtk7z&d+(I& zZf<7q*$g~F$AkgC&UfcNjS)rpBX+vK(IJV1=7XY+o5apvJhDbJ6tMri4n5FB*Y}Z0 z@bS-L5V*D^KU~5k(n@{uWFcP2)C)LsP+&l89bp`C4fn404g$Cg#CuC=PgnDZXzCj> zyE7fRpb)nD-1D*7z;NSmsk=Ih;5%>`vb47L*!?jIWPLjU;Pdq0sT45TOLQ-%O=tWN2Q2@ z9f@E4*zd3+EWrH`b-8cE-uQ-%DUEpw;*}} zvuzm)27+1lDZ;N#@PQ`CeSjLpL`@K?FJ z?ei;bTP1Yho}pmNg;>W%x&0*sx3m|bB;CG!yFtd1c-+Ioqd_K9+CT?vELC#@0rat8 z9I2_x!oq?I5DD|wz24!pqXzd?!q3r%9ckfD;Htzob}E7sj?Z+q9nT9@57xYdT(>Ff z@gAp^^n9tctu^WE8yJugc`o4i75IVtf`a9xCCF?IPzz%gpFn@@On^u$`SQkIAGD~` z(^G&p9i5#iPoDg!=#BaJ8kxL&HSm8W4yfO|$2cb2-Sg2Lv>3oqjTm&$Cl&hGm2D>( zUltduXN&AXa+zcjfm36B0b9&b@#>uHncG*xo$+)tf)+{EvL9x57?$fB1o zm6VjAQd`hQB9vmN1svMMTD%-^Grm9}>Bw(tYO1dnGZ?S|WI?T{W7QpxKi`D8xe1>+d@35AsfiZyDvkD8mCyXpgo&vd}4!C69EMG;Q}uAZLDMya`wd0sz0N*f!WfENK|CK#n@ zL_A93;#6d04C&^yu-2bd)AdRm!;ST^!%PUG`ubQQKKX|?#P zP|k|`dZl%%ynG~vOEaFwtl7ogeG$*x=WsKxx_S$c@cZUw1GXFcFN8jS5fTwO2BvMj zh|qzSm6a9ZfczTQmA?{?=MusCNVwPCR5H%c*ScJ5J+vjH#ir0hfbh<*i*>+mr{kqPg?a=8B#4cOZjjZ7xWYk6>>?NI!5(Z;uxB zzsx0hdbZI<_lsAt=&sV?Y2|tRyKFBNBS*jQjm;xD4F5`QK-yBgFtM{13XbA`>(k=i z-JYF-^orM=y=Ce?+rv+_K6@ON=1*hw9|20rj+wJf_^DDT+j5HywspWc&C?xWkyZD5 ze^#9vpS@ow&1HL~d5H3xVZ2H%L9;biICcL%%ftHGBqi9vPbxq`Lx#1eElg$Tr1vG& zL+v6C5!KIKDOv@@f7SEP1d+3$85CJgLb}pUPBW3b$v=6t<^R!wUpvj81Qrc$P?-kN zC+F8dK?n%{iXgUv)H3ox3DzH|Hj_0z<6Cbvl-javh(+p5<-^O+jL2= zwy;q@t0F6vCuPjmxEX)Iz@hzF@~P3e_2>Pk)psba#m3NgT}AzlhT!v>l&q{`tK*9L zwb1#poPD#yw!VlErhm*0xG#HZFp(|039UzB_p*~CzsL5mIo%R&WZcvc3@H&nwFD`< z8$8zxg?n2~9*~FkW|PY8{jDlmR=&}3%93__kGY7k+W%H)rTAqQF@fx-2;7i4lp*jh z9@%K@pwi)KWkaSbxafTCr+#emm>%b^cAS)*#peI*Y@mR5dX>cC-7&MueBX5Hm3NiZ z)*+1A>a)y9QaN0J{#)&%Y%|A(LhA7XBaCK#JvKJlL@7_Uh|WAdSBzbH#Ejz#0Ff9; zo;zn$IW7F;-$z1k$L@Ux-U;uYpZSp-&3WGb?~6J7_h(aL_`%x( z>HqzE0E2}6x?HhC=utr^N3gsdn)GxAUWXkWyeZ=v+~4rMwxU<@u%tKPIN;unn%X~X zpl%D}@0(?O4fY7kZKJ>mnFDADw|3CZaVt9WKTqfXGP(SJJuu3bjrT>Fr@$1*G;NP6X!_s7e?jSd817=G*J3EuSN6pR6 zfcTzE1U};;lZyv?7BTWEbO%}>^ne!OK514NxM_d#ra{$4h*tom9$LVG zpj>q9er4 zorGo$bdW{lOx`0H`&|Z0!$C&XEz`rs!g8@8QG`b*FCRWU2>%QW)cA(b!{2)_H#dj! zhGJu5OG>ux?(PCN4rVK`J0f{_674nj@P$G=qMDGY8z_dy6`=ZFiIg^i1NFZ+7hkjF!^Oq@p=xg4 z3IhO)DI*^~L@WUw8V3#n#!2Vtx3A&Beq=_-=1ZP#Vw&NP05H8=h7fX@gBH&a2S(JZ z2m~uLvjy!dlq?69H@IAQEyuB!{eXg!>*1uJGqVM#mD?Tf{;W!*pS~tu{w$|2Fm#{QbAW|9I;CjH;;kz*-K4j(G0uYEV~T;X z3FbTBlU?AoA%ZCQ!NNy->mC;VGZ;v56rvQPTw(kV+?tSQ zG|-l@TtIJMU0G&FHUZK8Dels60>QQW%}Z)u`;hd7KKdZ8Q|mXLJ|pD8*Ysx!V5z%G{ZOr;3*0zASmeV;v$7P1f*7N6BS6pkP~Z@ zl@JDI5U)O|do$K{Va2x+(eR9ojg0Q)JWlK##}SA6F7p!3Cchn?Jg8BJ+F!weV`%s7 zy}zJxkQlgQH+jbU7MLC(rcqR#%tI7C4?x;b+#=8e?3&ARs6$zxV2np#Wgo zp&}(ERiF?Lxfp%~R{?DOrbGO_Ca>QHO7b6<3Ux#1g@!oPDUnRa!WpI39vE1>kfm%*z_(6(?giynPtzS?^U**ceG}hzw>z5n{XXO!1%a0 zNWROp`J_8wu7g3kLu@EU;JS#1R0H@JJMazd;t~^kspIJ+X*ZrEC84QbKxHV?do2Py zJ0}=J$**0Dz7I9Ejjij7G%+I{s+NK4oj{uL#=ziUu81d;&X@@%{0hTba3+_67B0Ik zy9~&7BwMr>0cN%eu$TAtUfuySb3{Z0h*?@XIxq~enZ@2UaEHiVALT`(9G{rz#?y?J zLRl5{5_g3Ik9PrSclY$4$N?f~&@AJGf*^L-%YS|L~ z6R-q*q88OB;+Yt38ZYOoCo8U-dMA!HDMMokK=r^y=$T;_|L6aORV}WrdV>=ZWT0nR zS;40UxF>%=@D=s`1H}Qnk?Y`eO5^kDJvrD=RZ*FWT=LYVFoIWLK*6YS8i@d(wnykS zpwO?a!Bh%|c-B*$()|8?v1%6U&6}@FOW#?vhqQ-Y9(gXD^1W5_mc?hx%Y<{G^RnEA zHP#>&!=t$+Ah1G-&jPnZj6&&V;BXiN1`bl*#@4nVs?PiT^a#@47^K!q%Ph>y)*oL; z^WO9E5h~l-+Y289U8&ta!^d}YauT&x@a7HjBneK(_QUh`$Y$RiR(txeu40=qGtRHF z6Q4z$P`w%to`XybxI*#`T-(7p^RO$5q&e$sFg!h-l8=#mFby9_*DN7d#@j<_gP45m z63o=3VRZcbwa_JcySrU&7?{LG_5VaopDPDlNmco`jB#k@hLaZUjf- ztc9yX{5ymRck$!z*lHW79CJignI5pi(MmuiQ$o#=)aO+~RiiLu!(@E~rUfuV&$I*% z!WixOi!#J2AV3VZH~>GeNZb))W4ml7>G6YDpOTMsTfJmOl0Olc(`t{`$@EuIug7r{ z$;!so=>iRHalnR^gX1Bv^Fm3#fcX##pf0;FW7f3V=z+R!PFf3_N^KGyC9LE%$ClnaeNDNbG$*PhoF8| z7`+*ooz2w+HJ;n34kr!9SPLuir`gc_fkT4I=>&0!6SckrrH+1M!AgJ5?!g-X0XXe| zVDnaBWCi9d7i{35wxB#spzgm^5OM$?0t7RdaTmYhGV}9mgYNYX;#%gaE^sw3vmh;i zRRxAWo!b;o8jKIM&@aIrw`8`1G;8)jAN}xLn5(*qCh zL}zEHZU{8m&VNQt~8x__EdMyNH<^6h@$1C5FOGQV=-(!9szcu^1_a!j^FSPo7aRjq+Y+_;=25LYSH9sK=N$?s= zJLD7SJE6TH_sq-3xFB`_M*{-g~7<~rIQdA0~+Kv298?1*z@dc*s!5z zV8|;j?$u2r95-0NonQn&YZ0!LWPaT)24ML;OeksSKxRPTL-s(EfhQVACwacTJoqg{ z%b+#ry#;b*c^PD@yGBOGFb=@~t-%km%@twN3!8!wO%AQZ6`~nKtpo`8{a=&LfYtz5 zvl|&0ylC+FsfC^>PZ1z9F*_R!g&ZCwU{MAXs}5&N2ky5=wyn{}Z+eVdw%$RFWq*VU z3YB;kLZQ=tT&#d?TKayp13v@h2_}e_Wo2CYx>wO(Z9$y!mX((~xVZ4@>$brx0J2XN za6DiPt%xBa(X!G~NZ1zMD0a4T%n6^CNVG`VXPL#^XWX2}%}3B9^~eAML%7T5%ouW4 ziS{5GWROn)E>loY9ECSDH04xZ^8?1E38=Xp|EUlmt5v z6#!fejQJP3Rlj|!z#WFu=!OnVlIc4SX>0P+k}b%(gN<2pRn?m!BG1Z>Plb_4&-M^( z=z*u+_rk)$hM-7*Cl5%(mD0NymoEbV;XF^(F(6aBxe;p6itizlMvoqN4nKMlJGR z2YHp1vw&Y=V$LHyIfK_*FR1o_nFQvcdC8{+mc*3vE6&1fCSju7n%aV5k^#Rvd;3l+ zOK>zbHF3P4!z>M$PrY?C2Hu*}>&qUv_kIGvOco_i=0JHSoHL+{nw!r;yi(HrA}Q*s zB%*U)wzHgNF<-vUElAV(iLb@&P8O7&5C&ur(KK804C zpPx9xj=*hMvoJ*7Em&-VKtnaW{jVS+0`pvb0K&Sc5#^UonQ8<6+mD(XK&5+t>!?9T zNh$E)BmGh_)CeVLZ6C9&=VORM|P8w_Cu*9*@&6O&c$LNT37`%!o-BB z$bg!QG7};0%gV}DS61AX`C!)*wi~pYJMCOtHj@}+H@RGF9(OU{pYglq+gbhNKzrt4 zcV7J&l_a8KOUJXaEKSzO&hJ%-wH++FR&B);CJW+(NJ9p`D7i#{xN=>D{N0oE^dR7K zykeVG3SPe^@U*;SiQjS?%nGDNtwG{;V2qfPpLuiVK@F74Q%jt(<5^gY-2JIN)L8(& zj=f)#Bfz)Zy}KYj^R=_H0OqCjZ{MK~5R(~-pSCw`%>mFb;NLKOL4m1=26fHSn;=~Y zd%WS`$wpP0dc+NXKASW5_wF09B_x+Af#qMp3zv?QK%q8v{zDlU#dQ{h9hVM4E!kFR z<`fkwG6a?gWG2ZJaNIy&Z>*8xRF3B&m?1Pc#Zjcjn>I7^BW9F}kN){(ho?CkAS*F8HWv1UzJhx~W_mWi2TJJ0{j&pnxINt$ zLCO3RanRMT@!wDNIXC2m#DGZScCF*U$H&L=@LFp!X`6+O!#TVFdc}qJA95u}(#pb3 zsb)r|*0DHPUj{5+k-n>tFlNZJ^8XzbIcPR7rbzzCLvE= zyQZ;L(hAtGb6#d^`Z8>^3EPdA#KfP4EvCWt1mJ5JA7c~08&H~E#sU6EeQ(E4nQ2qK+S>Wx^T zLpYZjvw}YH2ag29Vc&jeE)k?m`r?K6uu!38iG$u zi`Hj&4XhpQ$vv-&OD3`PkBWc){;hdiKtMss2<{XxH~^!Bq14u%?NC#(&qLy@+Khfu%} z5*0RUx_KDzx0Uy8>v)rTAkz67gp(-C|DlCuv+b%#dEsXr|-CF2|`3Az)WFgd% zgi$%*>WvwDC~9yP<+D7JKyg()grtY{Jy@rJuoi@YGiHyd^!~pQ3FtW_l25@73<^)u z(w32cD<&o<6<>S<&h=rvb-8KN+d8(ZV2A)qIkdm{i-*-Pf@!eI4bnc;p~3NiMRmqo z>=fUdrBNN4lo3t4;89*2kN|LwqSk%HKn%lkEL46f;{YFUFh}_`46wzc$dgAzy6dAMeKy(Y))vNUjtKDZ z0|p%A2Sa0H_cw*~GIrLJ2+*N^NV{USqnt8m=Bn{=FhJCb>ZYcpwffZ#fRhGHIPf_Q z((zGo@(Q*$5^&V^y!+KUgud=kAad4mHRVz|L;=ohJy4HMS$XXec*5*&toswz$!S@ zV9%dp-CbQT4^8*M%K+cKjp@5HaX~-wyVEfJ2BW8J(RqS5tEWVk+sH zzHS_>dqDVD4=ju{jE;}r*e-GHo9OE+0bK_WC8f)tqztFnd$?9c%H8d41Wa5?CAT=3 z6px-Ce5$V&80qk}GJ0sfEpDMM=uG^V@+rBatPJ393e2epq#ZLCmp_1h0oGbrJqB?G z_Hne+N#w!%iBpT(2hTQm>Yy_}fe`3hB_MM!u7EUvXU2?k8!CZB9JsAw z@A>&1Ljx~@&;YJ1li!xHgEIL1k;Z?Rl6jP7WoAZ;2bFFPIKruvCfK`b0G=*>({U6^)t$W7%8FQzYpi2(kL-22?)XwrlzKid~}fIZTWa{vabN#zIxS9armee%ofmwdeXwOG?VLepL=n_506(G@xXZ-P@7@!4?KJe^=Eu=dz)JnY!^;CUE8yn< zn*&Esv7&{=3!toOYx#iyd@|;~lIeR{@$A51t14qx8{pudqfxJpFcUzo(<>*@*gHauO>~iEkUuD zXmHdeuK$vJpnx;1f7&ayJh$vP5vngmNzJHQk4$y<4iWv(Xz$g;I`HV-JTa3_?*96=##Su!oT;=iwgs8CpokR`?d8XF0A2#D3Y z@D}45Do%k847Sb=`W|9^&%n$4w-4H}kBHryyyh$Be&OtK{tWY#W*B& zfy)~{plHO-XHFO3UC7-n(q=CZ;#~xc0Rp(qwV^j7&HO>>90#-w?n*{x=8Gf6FJHa@ z$W2U3(}jTmY9UZd!J7RsNLa6ckHBbrjPQMxj9%~g`zwXQV|<|02ymstKpTTP`2e_k zM7c?04IoF2)VFU1)~}M0x&It1M6F4LCw+lo7kC&@`bOPg4-%l)m)79chzJEBlwrY7 z-|IDr<{RiVkXY02Jh>3y5fHhEZ%s;09&HW2fB_EkgKQQMr`)iDxd$I*0F+brG0}Aq zEbO;V-;aS2Ok(U{mq0^EEoDG!^J}-0;4!GO7<+jK*}O;uo*#(IWWgxs;syllXj;h0 z36`iC89mbl9ob`T9QP8iBBkc7Eug}bse?dy_io6{$Bc2lV3F3~Oy-xH(_#4{$^#xm z&6_v!jB;1tW3bu=@+`}ff^Rte@PZG(0h!5C-raqN+w>jy5I_$*48Ay_lep4`Kab+= z^U!a{33d#@dNK@kTSh(yP&`X^w>Z8&ZY+WxYsZfCT_RYR=BVg_SrlrN2pgL%7`(NM zRD}Ka%K_&BeF{cF9Ubx~hQL!FP5f9}GlH@MTM4eqAz}w3k{2+{qBf$>&(F)B6e}=) zfLs1EU7z^Z5D$zm&BB=AW0m4MsXy2fzi~ zDnIAusz9s*?WCiCz2?^onATV;EnQl;~KEcBV=72`r4>ky>3^#!_D#;b~=GOWO zJS1@42}gZ>O3>AT*C;n{ooQ>c0LBzjPfkQ`5Ii+iPXANccZXx$hi|u2N(gbwE`)}% zXJjU2kCZ|sBT}+MHra&Ctjw$iGAesDtZd2(QAQaN@A-Y6=RJ?21&vW3u z@9+KnemTdO9hudQnO`rtEEteKxvRh>avdXk-2 z_MV>O>C>BNM9kXKfZ_fatYG|^`Q*u_u;)U{hh~orQV`S*Y&nhqr3gziRkUL|)Rcm1 ziru+|?v9_kX$>j?=sW~WOh@x&Z0x1W$~=PMQNwb2b?V1^9;wl6VQ5K9Nx1`ZvE8v> z0FW{_&*}h(dY7|cfL^*b{LNrrgsRd2A@uu5EDUMX{OW+Y=+VBY?>#kpOl<(oGd;MG zQN2D*3zR>kFn~z3Pm~z?QmAVHX^D-#!AfZ__sj$ zpa1`VF$4NPcv5nEy%h#od5EqKtPSpZTfeQZH^q@Ikh`%2XPyqVi72-ekti_7?AG=>_pUJ-R*?YVUxzoLJ(bR?;kdFWH*&b-zb_K|hQpV4zaOq%}cwuTHC zug4A6NVqWVPY6wR>nRTDt+t8%_(HY(N;JC?y#>&if@mDmC~1)x0^#q#6(Ds<1>->& zC-tX3LdgxwGA1S_!qO!`xWNY|bl7b@r|MAyw1ca|2Tlso7>~F(#`s=z_;MFfav>J2 zrCtGC{_0H$P@NbZhaJb>pl$1Xgh1Epqu$eP4jlMWLLc)L$`imU$g#Jv z$%07;Ub5gjtENUTX3KrYR8KFaxY!XhQ(=A??5Z;NMfiIOIXP^d9>@(E#4) zAf|6FoVoWToc>n(&T66Y3sVB5^c_XlLCQCv;eX>WXZ(uHoM37r5b9HzlK)mPQCzxkY9NhT^RJ6&v0NGYjml+vx zVp8awzwq(r1)75g%kf%(Qs``6yGA2fcrFs{hX89m?SK9LJ=Rpy^-l^+ zay&Dgm*q~Q zPCrDr8?LUbjE#-K`L6)<0iai1ot-eI0iN%15nq=qLVqC1MK8nK zF^U9Cr3j+tS9c0flOpnhH#Cl737tTPXh_;n?+)L``7VV30puPmi@VSK^NHeGR-(tP zF+Zf&)8axkfeS+eg>(U@{}Eo^B;a&DXAY+@x*P#l^UJ%fUT;vWqCWR|xs!mWo+8b) z`~#!~0*sWyCg0p{Xw&F@k$M|j3-KXMwa#qeQy~g@c3+Zwp1FwKiqO1yF1hdkwPsPIeaRmE z#Wbbn$+(1oz_!rPa6yBv>kad=M}6~Lc*<&zq;#7|?DC)lbqPG&lgCD0TBj(UrrsD` zXKNJa7Cn{oY2{;le>n)_*3yw3MYE6F)}r1W(bGyBD+`e~pzb_<+F|NcNcZqZIxfF_ z;)}~l@R$EFnrPw@3V!6)EX}rv(k~@;cCRvj5yfo>m5)PNoD>(A_iDLmxpeRKSmq(h zL(U(s<$kL>m(1`9r)*L0)Ee*{{+D4bkKI8*PX5r}KdV1%H*<-?zKehKx&j&1A7q;2 zg1wRR$f?8I3~Ouv-YYwI?##PA54 z)FWY#fV2sa&n!AVfPO+lL!lV{5_Ny8eSE4@X8(}!^!rEtJU&FgCrZhH;U7QNVNi-i zMz#>j-Z3#gUPpJI{GGbl(?u4RH~z^Lf-+>H#HTj_Z362hHZsy{1??SUDNauCi@|+` z7J-4e%%%Iz=%D8T*D!v61_pt^GqfkI2PF{Ih4O)AWM(#u`pT!%f3NpN$t}7YSrTgBU@WNr`ORe+9)iee%aoxwGOzTq8YW$1a7t>6{t$H4{He0`l;up1E2(UO!*-4(Ie(eNm zHo-3Qz^CmEogdSp23RufQis-f+t*(jzFhsxJi@ilaCtnV@!Z*nAs`9{0Bhjx%g@5Z zG=OLfLL5#aun3EG$fdtk@h$r(pwOH-e-%TAxrUMZMWdZIy15Hk2j@QDyJHs9Jb%dS1zhCz$2|X)tgY%av_3+ zTUCb!ylnAGVhwlaReer@$=^>$4vJ61$f#nQ%%lF9%l-^XfCKW!}-d{vavg{ z-pn}1u4SmIt8zkcsaH6=kTbCSf$A?5B{nv;o@Wn3!^^PT*&r-RH+}Y?;>P!eB+iU) z8OmjA0+XYR@@5($nU6WF*P zav%e`{Zj5a@tn`9n36w^*r4^`oDhxJU@a5-YI`)}1l0qepM?T1S;|W>=RY-uv|SVA804_NJvr)v*DRukk@|bSaL~lY~kJ+8BijboEp=^WY`KR&+2H&4pZ29YBTJB#& za-5tUGrnAQyo=)28%d@pA?*NCt?OtIvJ;4cQ6oA!If13x_OkA*+*+sKGvB!VHrcX4 zE}HMcDsLnU^te@SkLFZO{QN>7>)t|j#6Ewx8w6nZ?E%XRu<2o*@hnA^K2^FXWA`L(l2x4hHO|R+{lE_s( z+DFf1ENvc6LQARSGy=!L{L4px3YbIzg@cY zmpY+*r2AJDuUkN%%D46CN=o{$sR@Q=i^VI=4*W~ovI?gtO$A}zBn!$@b2DLCjfCgJ zhxMn4OG8X-nkLCdq*ce4)a7Db0!QDeL`;RJD#o;T_7p%B(#a zCY@{_%v+k9@7maux~vpEdj_k>)*_86w#zk?qoNX| zi*+#5xY*H__T$-UCKYL@s>&PZmAgImGas3(a{M9V)o$d!U+Q*#Y1gryiVyGK|GeHL zP5ZOD#wmp^siwr5$CG$-M`6OUx8(I}09`L`_31#-k+z{Dr;ew&%dE()P~hD0T)lgOz^o2LA7=2i6c;O&j-p-0&w$rt@OtYi8r)%#pE_p}+A?bqZOC`V8DpZiU? z$7q_sG#=Umuyb)nS7xcjPNzDl+~|}yD*2nr^ZKjS=LSYc$5Kt5m?h1!qdWdPe3XK; zcOsfmYU1u7mCDY%H?#2V=f)}RFcR;WjE$qC=x&RA3WG7V(Z2l|#A)^43o>_^AACK` zpEZ|r+x*f>T0=3{`Tbdie8f`MNNqT6o(mhxA$#Ju#TCM8g#j6dvbJqCn_$#^;-^pR z0Qh8NWJHj75B&Fn?BVgCCXyon>=RGw&~GHHWw^K$fz@^B`6Kl0&@j2DL2I7d-N0DM zcEm4^L4V^u>_pXZ63b|7gJp3-RP=F9wiY(o69sf3%2m)7z40;e{ZZCRhW|dp>zjP+=iSitfwg>j&42D$*H z6Tx}`<&qx-Xw2;FZ7cq{TmSpI?u8uK7jSOCWZiYO$Nu#VsdBSshMx=ep3isZOR%oB zTP-Z*>>a3o=E(U{VxzH_cVF4S?LehJk)jp@f-ns7aQBnE0L2P?(!siO0a%B4TMVy5 zlu1{B#epPKYB%Kb{Q2|v_)EY|f%`&$3#;!eh;NF)xhMlbouKoV2uEA93dOR-1jQWd zm#(Z8jKn(+j1PTNy>(;2u%uq_D2K0w)?4!Y64!?;i)xN1#)sC9D>`+y8g zZ-OUD60$4-)2rZwl`WN51Df|JkXe4hxNa0cV}TQ4q51DKyk~w+sqxP3ko&(l#|!)1 z+P8L16tDY>mU^ErDU9N@lSy5-p)XpSi)T7Ez9d_okMTR zJqX%zpekw%`}aSGt4mFd)crr;c@{)%YVzU(FL9SlQ$vG<>spKfYbaU>Xf>iKx;vkQ zQ49$Q316C{034oCLNF6LSn9lFwH$76I+DRghr_wN1ONa~Q$W8HbHi2A>M+u06<6S* zd2K6aPD9<+)&@p-JDgVNEwI7i){)tQ=EZejSn8r6mmrg(8Jdw(P#~!2s2>X^;SRdZ zR9slN7}Umf5Uj9jz+XRmmglv8)Q)oU)nSugGkr?{g;J!VoCOyoJN>ruA~7T!aW3`5 zOE^+9CM74cQM99bn30{`bd*vRJJ|n3NSBg1%?<#Z;x@x_ei>lpBgkKW1n7bh4RRj1 zw7NIt^2vDSzH=l$aBOWX`U)fjcLA)T6>jzJ0AxPJmOUOLtKcG_izCI#>TaNRgY_r= zmX*FETU$j(=W}ao7}C0+pQ#E5VJ)k_7et^wHEiVq{hC_sb%%b$vbsNos$Z@YfD3N* z;HK<2rX!FTd*!>jLsxsIR3n3I|D9y=f4Tf?vUhBt0kR;mAq5(b^q<56Ba@ZZ=-+C`Hzs zGZiguI&cY8@n~Y-hfG3BYUSgaCkZUN9$43D<-wO@OUAnbAJjrXsn>m3Tw|o&Dhmtw z#i;@&;Bb9J#3>3BBnHvRU;pzxzndQ=7PnO8r*zvCvd-)}cp;jt?yfcy+PmfYb4_J!2ma+bP5AwQto|abg!(*Yr|K{{H?(C+e;2M+5U4__o~2X%kdAp1oLZtk1NeGA3`YWmNl8uxtY z?9|rRXHYb3NX4V}u`S(xrQrELb}CgacXgwaN&}Y^t6-RK&2G zG!ZrdL0J~{t9S$+CaxH~xxyPmm^eaV7mh~Lwtij|)*A)?hL(he$^JypgG@ z<0Ft76z-$HlCxQj)58Sx@7`i-HvscROSK+ET)6mV6L*w7qoqxEg-qx#(B%lyWOt+i z>n{Yo{Td@n30|5(NKum@7aC`iJ!r`B_1(4U>8|2X+-f5u8w_M=0Uq6jVGQ0ohs8Bs zn!_PQ4Qi4=ll!;2Tce7LVJp6S&9@2k6gtPJ>Q+`iz=4I;LNY|czslr6 zh}-W`bvnLVgH0Xl7D7tP{Cjq2`TQL3kokezlh@_+&N)p)H%d!r*S3hOei?jz$L4m(z>VRVKT^pJj&2$z z+LY9fvhTIEFFI}$qRz`cAt9*qXNKTF0RIFxXves6kgV)j4*$`kQ1<4cw}?PYjB$Z?9oL5E&M-xsB-lXGgb6V=eXaE_ z9**uN5S(CYME2%uJJ=nIfkdVy0__OINrm=2UD);x4&R$KM}zjCgoq zt=WDqIMJqjvedL6h?OVZx@VLWbIuB3P}wLR9BpR)lyrAtK7{s%h~j`V?Pc1Co|zCN z`g^(@URS-H`P{5D_}te)D}fZo09vnVmU?F9Il(G7h;`gokf(Uxnm~n@%gpX|Pw((H z{lSc?kukJmyu5$pTrd@y3%blch9!)aPOBpfWpDQjJy6Vd z6LI4JI-DF<~P8;t`O@J__3x(CPu<34(#+UKiQ8zV>H;U8I!*qKHcH-2MY*nQ^I{SB8XpKmbG92RGIF>UjvHiYh{TaapS#pU-} zCOUSB^ArbAS3P`@n`_n|YicrI{FV!J^{V{Enxj6>BcSwJ&^&rthpZML!Rezo+c8hY zMELz%1CU(*Ai%2mr>{v`HEp*<3$Yi6#}&7~FHw-|v0ZDNzj=7eWKn3q1q&micb#;c z!pDF5nAFISOLN4zx4(K|IwLLXN*7Ya$KpHSKxeeX-nM>y^i0}>->$4jROD^Lnc$iaHm#LidW`1X z23<(m9LYWx$FrXwBzvpAC73zj9GCUB#1}^d5=yB2QuhveRa2}vcT8>FuY*o4c6%WI`mRA1q7jXe+7mU5?8)E2zn%XdI4BbfeZ{Rlo;FezH)oaJ)~TI zA#hUQ`FZAh^9MxB8}DPi_eS!%?VI&=>Jz2jfEAB9l=I7@EwgJJ{G zdkq}6i=BMD(naN{326Mh74ycn0->5R{tc-<`^E$QSKhuCuljU`R@?#)ggVXQge)NJ}fwe1gi+STg+0_UTp^7Z*J)4JG*7 z`%JUZXzvvIdv@s08>T6rg_%RHO=2dDUj22CR21nu$l~ta1%Rxx{YpjSW8%A?6SMiN$6IzZmJ>RLr4 zM4`q8C+viKAP`|q#>=2VkFwh zmxQggPNmiT+>%wO@Lzf8f1-c>3hDp+nFMzpoVSKhMh1CwaGxYD2H)*Imj;J{oXhe5 z>KzDNhO!j0_`y>~vj`2_dWkG`CDwnFdasV<`&>pkCaBbrT_VgIzR8OZi4F}t0S^LH z-7OI$?`{9pheEyEMO^v*T-(&y{PWaZq;34Wo{@SSf?^P7U65b^BWI?jt|DpCmn@;%<}xBb?~s<({+W>LN#LS zsU99x)^GlTGcF7Qj}ZNgaClr?85p@NZfkSs#a*JEOZcCc&xVsi=QEn)P_csJ=NVeh z1A^p{qeuUuf$Oq*h9AZqh0&`>`cQ;Se*fOwNjE{GNem8{_UGe49o*rnUOKB+aD*HG;rPfblF zbipQdd$_o`Ui~sqf|Ux2N>JnY-QZ54VG@?$Cqlr(fU$WB%K8xH&^rxdJ;0~;1*5&1 zP%+{*gNFvXf)$dc+{1n$Aq3+QnQ<6jYsrn^d@gFOSW@AAIJPB9e%hQxHJ69532w~> z%73`R>&5M^7N71dG}%eb)QV=Wp~_T%t_iz z@q<64g1v1N4+U!mGuEuEt*=O^QdgZx&%y|gLS%Q{CvTst`LQ}TKYtNfGoI*R-eqx+ zj7NSY1P;v8%Uf0qe=m&T8zp~q`b(Z8U&1n*iWZ! z>8q7+(#yZW3ITTF85^Z|(ogf7NZ>mFzjmX5vZ7hm@sqi>-76%&*s4C)}78ig&;N zHk_tRrOnaM&`?lN06lTdx;eDHAa`L}lVA9~sSfb8(V5`r3Z}HeB?c}pFB6;ORv%ql zE!_H{#8H2okZ>QY1MpL1L$N|J^Fh?&5*^yuZy+8ff-m~a*o9e*T{*S#K`WEDTS)cl z8)6h!e#@Cp&bX52?NfgqUp6-Wdg~(BQA4L?1p?utHIho+J$v+P*~J`dY18(gciD6+ zLtn%Dss$5=4QlOqcbI9|b~gC=`-|exLg1&QEjSg=Z=z7T=v-r;5#HHAwtF`!ZB4>> z5*8qkf-eTX!$1i;gYU#80St9>bNgH(@F~&W#u;?M>0vung~(1l0VWE47n7Ev98Ez! zVP@CcE1wzbyEd!ac)Tvvm9FaF>zg!Da?QmVBT_6yT-=I$YA3Gi?#$}(lZmCdCwLM^ zAY?c}S3@ziQk#lh_EBco*Xd~qLfC4j2FvsZc=Oq$9Y`0xeH($Foou$nLU27Bo$@fIO|+_u@|(_-Zz*CLWAOyoy)(uxYweN~Bdsd)Egoo4St^%b*uAF_01 zP|WUUT@*YC{Ev{7SKfCnQInB*hj~DcfKDJNLmn0!^3Cov8yg$a+Z^n2fMM`1!H|!d zsa-w=*4|-`1Y#na)k89#~+H~B+pX~I2#LH_fYhd8pqc0J_tpxY=>kd>P?0Y zhBOfpYrAPOduSlwq<@#3Ov+0+sN8maVNO(}oY#@WyOEeKd{q zVr#6u1VkMl%|8A2qNM~+QZhhHvI-t9Y>lnP&uFzhh_Yc-36e?$?JN34$12q;af?5C zR!7~QJl6kS*g^|;OT&Z;_e$qo+cnn0?|{xBXF`)5&Ulb{4s)Q+sez}XCCOLe6P zPaY0c;!Y7=V(t)<$~oVzPzsyCkLlGhDfiEjxhe(GZZ(^KY~dQ)A?79Em|kTT4r=8U zqLbLwp(iVLuqB$}a)s<|+Y6Hgfj0-%W+%=hQz<0ZIF2L8;Jzs6@`V5M<-Cy(MWVkn zIm@12ZTzmy0d0YfJl;ogo>C zhi~tRD~+%C(Qnk-C>nRG)OLd5++-*vqh74-SbF38S2Neo$3J(hJ<91fMSQ1iUNw>~ zlZum}nY*rvyG@K9pudx(2XM6l$yP3wr^8II zAQC!EZPBTA`130Zo|IswQ&X3iZY50Qdk_gLU{JYtYmTgBc^Rq*Z1jbWiQ!=nLA{F? zVHjEvi=bCpR#sM8ngDGMTHM;yUW5v#sR>}1gR`*gs5=R&&N^BtHwD5(a{?2pPW&7h zIR#QRru4wD7JY-=oF=yj2QXm2AXlKRc@;iQnWYdi;=Hi3)JfhQ02MLh$L{nF$Pbc2bT760%jhh+^;zBxAU@+eMKrx0Y=ExBySVD z>Fl?Ucy+@gc-<_IVDcG0^UikJ4(=3%O%IN6b6-`^(TRG$$xmz|Bp4X!7UckoHT)ag9lCM8ak~#{;s9-y*(V3_R!| z^Q=bvJ%Na;u}ufS(m}K9H*Tc(uzKy)NENvaI1w__&5C1iO+l~bZQjn+lUDj;516=z zx9GCU1yOg3fV_E79y4T=^z~QZR)C1rQ*`|yKx!|T)_FHH2p-+^861QAa2Mj;v17-u z_X(DR__GU!(H&q^;F(gKB1^!3aF+eTMHP+%T`f411BDPAT3$XmDnoW8O~K}ETN`u% z4;q3mm71$*XvChpzMJ@bc0x|Wg28?`HJ^yp$N?N%s|?uAr&qvw^r$MNm42kK6}1S= z0So9Jgv4?c*c@)Lv5)o^plna;P~8I>+_!i=XFGs@Imr&cCXZL!*wR8twVN|-;N>81 zrL(ZGaApgHy@e24Z(%vPx%kV1DZu?hD%05o6G((JK{X<>G1&~1B3KX(8Kni(FXRpz z21|&ckD?Qz6_9(v;xpg0!sSbs^bV8n?0zsjgAuX%huIFD4njUS4i-41G8@YyAt-~8 zP~B`x+x0V>Ed*&EJIB1eQ3W0jt%R()j(8HAf$7KY@gDa2)z@wIe62^KIafZSqBT{{q7T&g3it9 z)5L58*pGxBv03XtKMdgO!7;Em7)T^`PP~GuUh>C}A0o`LCyc6)-4x6{>4sX_@<#WI zf~>15)hY0UD5va$;EZmnP!cO}C)Ivy-&;F&8cjdjC5!QU0b2c!EclY(uuJ9Ve*tWX*9eHArFK|F(1 zY8359XUFd@hS(YTgo`f6#RJG!BcEZFxRHZ{jz|{$i~a72`d!V zCCodObA6|<;CH8!a<4MJv0fF$)(E-A>naq?`wHH=iP7 zz<19G52uzrWczT9m@?DBvG=OFjDhULTb&2p2f(gmVtOyW(zq|!41(#1@ZFMcn0@8C zQyjVu$Y44y#$0DhQgeykQ@>9gq|D`?z<-^6LI!%E7w29mib$#xe>mg4>CjMCKJHPGv*+ z()@&MKUVrR3*^;m*jwbicM6`Hs&S+vFxG6s8@hLp`Ofs0?mr2Q4O#dl2yz}iR6$mX z*)w{Katd^tK`4irVj!7c9X1Y9O!`q@k#3l*!Hf;l{(hqfym6Vq)6fu_9{@TpTnr!}}Al@6h=;=hD3oFI;$f*62(Z={xhib^qyoc!bW+mNYOJJ8K*OrC%fx)Q*Px8RJ|f4I zS5a~lRNJa!Oz(;cPFkqFho)-8EdU0vGA_siWDn|9INj4{R-~k*^-g^NU6sw)os{(S zT4rY*|Fox1t$K>aunsRhW+h`Ie2%tKG7q2B>_;kn`jxHYT@PP zw(NTqbZ-vMTm9Hn;1Q~yx;x$i?;W(cV2P58h91Hh&Rl2>u*i-%M7Ur`p5xL7PbgNt zj>Qm0VC9adq&m@lTVZjB8Q`(w$MZ8YZyFn;=)hq6aKf5|BtRyJ#)bS={Epc7e)~YQ zsxd^|vxD?(E#!;agQaCQ57ul9XJ(_I(d7s{eB?+e?hLVOtooM`L*T=PIe$klzHlD} z&;q+LXehhj?n;)$le_yS0k}NW6KY*xtT;-Cx1K4a zyoUOfaKenjP%oU35FMy%Xc#A?%{CG_-2n_y)sU4nM%=u)E8*K4ei=$_(99`!2IN3p zon$ZEj9MvlY$krQ7C$BXC=i=YMG1=_|jfN!^<}b6<_%mdNNm@T(^)vrx zvDbFOL;rk?H69u#Cnu;2nlRiy+(9-|=zB>~LE+WSR%5vS-n?pQN%Rb^U*9f@A)!=% zuHD-7xkb|isxk}@eLuAV>mEX63}izZhY%uf$__T#{3@}+8}4d;#jeYYg~ z5fMSf0N}8j{p!`LZ#4-|L;b^eYS25r7S(TC?XfZ(O?uGkz{k zPFde_=4_?M=y}{&IE}MT`92WVa;kS^|2-P_+#Lt8BpczQNXOBUnzgiR)Kc&T+l?)VON=@C|p#?E?l{i z6vX;&pOiZ?1z-(5F|;@a(9s6A&1Veo&N-B`XZ!hb72I-rw9%0#boVwllQ(%S%+l#S z!?XxM(|~6X!7sT~A#{vH6zp8i%FYH!gl8iGm+eHXsN1+bNZXreHtl=o*x)w3A_Z#D zNg*LA*EJD~m5El47jnVZOcYN)z!8p4aNDO(?oq%-FbEoWa(&qVVvVe#qVKMKk8t2B z270!2YcPi%-onC26|q?y4}5qlYM*#%x5P%;?*JkT(Ci)jg`-d@D3cgL$LH5sl<`8^ zZT;YJnH}Dbz@S*@(*E?-4$sOmk<~Ao5CGuF|9{Mc|NsB>yZUb*H+I$ytzf3g#T}dQ OPg(xFT%PPzkN*Wni%g#Y literal 0 HcmV?d00001 diff --git a/moses/latentgan/README.md b/moses/latentgan/README.md new file mode 100644 index 0000000..ccc021d --- /dev/null +++ b/moses/latentgan/README.md @@ -0,0 +1,26 @@ +LatentGAN +========= +

+ +

+ +LatentGAN [1] with heteroencoder trained on ChEMBL 25 [2], which encodes SMILES strings into latent vector representations of size 512. A Wasserstein Generative Adversarial network with Gradient Penalty [3] is then trained to generate latent vectors resembling that of the training set, which are then decoded using the heteroencoder. This model uses the Deep-Drug-Coder heteroencoder implementation [4]. + + +Important! +========== +Currently, the Deep-Drug-Coder [4] and its dependency package molvecgen [5] are not available in pypi, these have to be installed from there respective repositories (links provided below). + +The pretrained models of the LatentGAN are currently not shared in this repository due to file size constraints. These will be added in the near future. + +## References + +[1] [A De Novo Molecular Generation Method Using Latent Vector Based Generative Adversarial Network](https://chemrxiv.org/articles/A_De_Novo_Molecular_Generation_Method_Using_Latent_Vector_Based_Generative_Adversarial_Network/8299544) + +[2] [ChEMBL](https://www.ebi.ac.uk/chembl/) + +[3] [Improved training of Wasserstein GANs](https://arxiv.org/abs/1704.00028) + +[4] [Deep-Drug-Coder](https://github.com/pcko1/Deep-Drug-Coder) + +[5] [molvecgen](https://github.com/EBjerrum/molvecgen) diff --git a/moses/latentgan/__init__.py b/moses/latentgan/__init__.py new file mode 100644 index 0000000..bc77b73 --- /dev/null +++ b/moses/latentgan/__init__.py @@ -0,0 +1,5 @@ +from .config import get_parser as latentGAN_parser +from .model import LatentGAN +from .trainer import LatentGANTrainer + +__all__ = ['latentGAN_parser', 'LatentGAN', 'LatentGANTrainer'] diff --git a/moses/latentgan/config.py b/moses/latentgan/config.py new file mode 100644 index 0000000..47d4e37 --- /dev/null +++ b/moses/latentgan/config.py @@ -0,0 +1,92 @@ +import argparse + + +def get_parser(parser=None): + if parser is None: + parser = argparse.ArgumentParser() + + # Model + model_arg = parser.add_argument_group('Model') + model_arg.add_argument("--heteroencoder_version", type=str, default='new', + help="Which heteroencoder model version to use") + # Train + train_arg = parser.add_argument_group('Training') + + train_arg.add_argument('--gp', type=int, default=10, + help='Gradient Penalty Coefficient') + train_arg.add_argument('--n_critic', type=int, default=5, + help='Ratio of discriminator to' + ' generator training frequency') + train_arg.add_argument('--train_epochs', type=int, default=2000, + help='Number of epochs for model training') + train_arg.add_argument('--n_batch', type=int, default=64, + help='Size of batch') + train_arg.add_argument('--lr', type=float, default=0.0002, + help='Learning rate') + train_arg.add_argument('--b1', type=float, default=0.5, + help='Adam optimizer parameter beta 1') + train_arg.add_argument('--b2', type=float, default=0.999, + help='Adam optimizer parameter beta 2') + train_arg.add_argument('--step_size', type=int, default=10, + help='Period of learning rate decay') + train_arg.add_argument('--latent_vector_dim', type=int, default=512, + help='Size of latentgan vector') + train_arg.add_argument('--gamma', type=float, default=1, + help='Multiplicative factor of' + ' learning rate decay') + train_arg.add_argument('--n_jobs', type=int, default=1, + help='Number of threads') + train_arg.add_argument('--n_workers', type=int, default=1, + help='Number of workers for DataLoaders') + + # Arguments used if training a new heteroencoder + heteroencoder_arg = parser.add_argument_group('heteroencoder') + + heteroencoder_arg.add_argument('--heteroencoder_layer_dim', type=int, + default=512, + help='Layer size for heteroencoder ' + '(if training new heteroencoder)') + heteroencoder_arg.add_argument('--heteroencoder_noise_std', type=float, + default=0.1, + help='Noise amplitude for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_dec_layers', type=int, + default=4, + help='Number of decoding layers' + ' for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_batch_size', + type=int, default=128, + help='Batch size for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_epochs', type=int, + default=100, + help='Number of epochs for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_lr', type=float, + default=1e-3, + help='learning rate for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_mini_epochs', type=int, + default=10, + help='How many sub-epochs to ' + 'split each epoch for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_lr_decay', + default=True, action='store_false', + help='Use learning rate decay ' + 'for heteroencoder ') + heteroencoder_arg.add_argument('--heteroencoder_patience', type=int, + default=100, + help='Patience for adaptive learning ' + 'rate for heteroencoder') + heteroencoder_arg.add_argument('--heteroencoder_lr_decay_start', type=int, + default=500, + help='Which sub-epoch to start decaying ' + 'learning rate for heteroencoder ') + heteroencoder_arg.add_argument('--heteroencoder_save_period', type=int, + default=100, + help='How often in sub-epochs to ' + 'save model checkpoints for' + ' heteroencoder') + + return parser + + +def get_config(): + parser = get_parser() + return parser.parse_known_args()[0] diff --git a/moses/latentgan/model.py b/moses/latentgan/model.py new file mode 100644 index 0000000..d4ae374 --- /dev/null +++ b/moses/latentgan/model.py @@ -0,0 +1,226 @@ +import torch.nn as nn +import numpy as np +import torch +from ddc_pub import ddc_v3 as ddc +import os +from rdkit import Chem +import sys +from torch.utils import data +import torch.autograd as autograd + + +class LatentGAN(nn.Module): + def __init__(self, vocabulary, config): + super(LatentGAN, self).__init__() + self.vocabulary = vocabulary + self.generator = Generator() + self.model_version = config.heteroencoder_version + self.discriminator = Discriminator() + self.sample_decoder = None + self.model_loaded = False + self.new_batch_size = 256 + # init params + cuda = True if torch.cuda.is_available() else False + if cuda: + self.discriminator.cuda() + self.generator.cuda() + self.Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor + + def forward(self, n_batch): + out = self.sample(n_batch) + return out + + def encode_smiles(self, smiles_in, encoder=None): + + model = load_model(model_version=encoder) + + # MUST convert SMILES to binary mols for the model to accept them + # (it re-converts them to SMILES internally) + mols_in = [Chem.rdchem.Mol.ToBinary(Chem.MolFromSmiles(smiles)) + for smiles in smiles_in] + latent = model.transform(model.vectorize(mols_in)) + + return latent.tolist() + + def compute_gradient_penalty(self, real_samples, + fake_samples, discriminator): + """Calculates the gradient penalty loss for WGAN GP""" + # Random weight term for interpolation between real and fake samples + alpha = self.Tensor(np.random.random((real_samples.size(0), 1))) + + # Get random interpolation between real and fake samples + interpolates = (alpha * real_samples + + ((1 - alpha) * fake_samples)).requires_grad_(True) + d_interpolates = discriminator(interpolates) + fake = self.Tensor(real_samples.shape[0], 1).fill_(1.0) + + # Get gradient w.r.t. interpolates + gradients = autograd.grad( + outputs=d_interpolates, + inputs=interpolates, + grad_outputs=fake, + create_graph=True, + retain_graph=True, + only_inputs=True, + )[0] + gradients = gradients.view(gradients.size(0), -1) + gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() + + return gradient_penalty + + @property + def device(self): + return next(self.parameters()).device + + def sample(self, n_batch, max_length=100): + if not self.model_loaded: + # Checking for first batch of model to only load model once + print('Heteroencoder for Sampling Loaded') + self.sample_decoder = load_model(model_version=self.model_version) + # load generator + + self.Gen = self.generator + self.Gen.eval() + + self.D = self.discriminator + torch.no_grad() + cuda = True if torch.cuda.is_available() else False + if cuda: + self.Gen.cuda() + self.D.cuda() + self.S = Sampler(generator=self.Gen) + self.model_loaded = True + + if n_batch <= 256: + print('Batch size of {} detected. Decoding ' + 'performs poorly when Batch size != 256. \ + Setting batch size to 256'.format(n_batch)) + # Sampling performs very poorly on default sampling batch parameters. + # This takes care of the default scenario. + if n_batch == 32: + n_batch = 256 + + latent = self.S.sample(n_batch) + sanitycheck = self.D(latent) + print('mean latent values') + print(torch.mean(latent)) + print('var latent values') + print(torch.var(latent)) + print('generator loss of sample') + print(-torch.mean(sanitycheck)) + latent = latent.detach().cpu().numpy() + + if self.new_batch_size != n_batch: + # The batch decoder creates a new instance of the decoder + # every time a new batch size is given, e.g. for the + # final batch of the generation. + self.new_batch_size = n_batch + self.sample_decoder.batch_input_length = self.new_batch_size + lat = latent + + sys.stdout.flush() + + smi, _ = self.sample_decoder.predict_batch(lat, temp=0) + return smi + + +def load_model(model_version=None): + # Import model + currentDirectory = os.getcwd() + + if model_version == 'chembl': + model_name = 'chembl_pretrained' + elif model_version == 'moses': + model_name = 'moses_pretrained' + elif model_version == 'new': + model_name = 'new_model' + else: + print('No predefined model of that name found. ' + 'using the default pre-trained MOSES heteroencoder') + model_name = 'moses_pretrained' + + path = '{}/moses/latentgan/heteroencoder_models/{}' \ + .format(currentDirectory, model_name) + print("Loading heteroencoder model titled {}".format(model_version)) + print("Path to model file: {}".format(path)) + model = ddc.DDC(model_name=path) + sys.stdout.flush() + + return model + + +class LatentMolsDataset(data.Dataset): + def __init__(self, latent_space_mols): + self.data = latent_space_mols + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + return self.data[index] + + +class Discriminator(nn.Module): + def __init__(self, data_shape=(1, 512)): + super(Discriminator, self).__init__() + self.data_shape = data_shape + + self.model = nn.Sequential( + nn.Linear(int(np.prod(self.data_shape)), 512), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(512, 256), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(256, 1), + ) + + def forward(self, mol): + validity = self.model(mol) + return validity + + +class Generator(nn.Module): + def __init__(self, data_shape=(1, 512), latent_dim=None): + super(Generator, self).__init__() + self.data_shape = data_shape + + # latent dim of the generator is one of the hyperparams. + # by default it is set to the prod of data_shapes + self.latent_dim = int(np.prod(self.data_shape)) \ + if latent_dim is None else latent_dim + + def block(in_feat, out_feat, normalize=True): + layers = [nn.Linear(in_feat, out_feat)] + if normalize: + layers.append(nn.BatchNorm1d(out_feat, 0.8)) + layers.append(nn.LeakyReLU(0.2, inplace=True)) + return layers + + self.model = nn.Sequential( + *block(self.latent_dim, 128, normalize=False), + *block(128, 256), + *block(256, 512), + *block(512, 1024), + nn.Linear(1024, int(np.prod(self.data_shape))), + # nn.Tanh() # expecting latent vectors to be not normalized + ) + + def forward(self, z): + out = self.model(z) + return out + + +class Sampler(object): + """ + Sampling the mols the generator. + All scripts should use this class for sampling. + """ + + def __init__(self, generator: Generator): + self.G = generator + + def sample(self, n): + # Sample noise as generator input + z = torch.cuda.FloatTensor(np.random.uniform(-1, 1, + (n, self.G.latent_dim))) + # Generate a batch of mols + return self.G(z) diff --git a/moses/latentgan/trainer.py b/moses/latentgan/trainer.py new file mode 100644 index 0000000..69a57ff --- /dev/null +++ b/moses/latentgan/trainer.py @@ -0,0 +1,271 @@ +import torch +import torch.optim as optim +import numpy as np +import pandas as pd +from tqdm.auto import tqdm +from collections import Counter +import os +import sys +from moses.interfaces import MosesTrainer +from moses.utils import CharVocab, Logger +from .model import LatentMolsDataset +from .model import load_model +from .model import Generator +from .model import Discriminator +from .model import Sampler +from rdkit import Chem +from ddc_pub import ddc_v3 as ddc + + +class LatentGANTrainer(MosesTrainer): + + def __init__(self, config): + self.config = config + self.latent_size = self.config.latent_vector_dim + self.generator = Generator(data_shape=(1, self.latent_size)) + self.discriminator = Discriminator(data_shape=(1, self.latent_size)) + cuda = True if torch.cuda.is_available() else False + if cuda: + self.discriminator.cuda() + self.generator.cuda() + + def _train_epoch(self, model, tqdm_data, + optimizer_disc=None, optimizer_gen=None): + if optimizer_disc is None: + model.eval() + optimizer_gen = None + else: + model.train() + self.Sampler = Sampler(generator=self.generator) + + postfix = {'generator_loss': 0, + 'discriminator_loss': 0} + disc_loss_batch = [] + g_loss_batch = [] + + for i, real_mols in enumerate(tqdm_data): + + real_mols = real_mols.type(model.Tensor) + if optimizer_disc is not None: + optimizer_disc.zero_grad() + fake_mols = self.Sampler.sample(real_mols.shape[0]) + + real_validity = self.discriminator(real_mols) + fake_validity = self.discriminator(fake_mols) + # Gradient penalty + gradient_penalty = model.compute_gradient_penalty( + real_mols.data, fake_mols.data, self.discriminator) + + d_loss = -torch.mean(real_validity) \ + + torch.mean(fake_validity) \ + + self.config.gp * gradient_penalty + + disc_loss_batch.append(d_loss.item()) + + if optimizer_disc is not None: + + d_loss.backward() + optimizer_disc.step() + + # Train the generator every n_critic steps + if i % self.config.n_critic == 0: + # ----------------- + # Train Generator + # ----------------- + optimizer_gen.zero_grad() + # Generate a batch of mols + fake_mols = self.Sampler.sample(real_mols.shape[0]) + + # Loss measures generator's ability to + # fool the discriminator + # Train on fake images + fake_validity = self.discriminator(fake_mols) + g_loss = -torch.mean(fake_validity) + + g_loss.backward() + optimizer_gen.step() + + g_loss_batch.append(g_loss.item()) + postfix['generator_loss'] = np.mean(g_loss_batch) + + postfix['discriminator_loss'] = np.mean(disc_loss_batch) + tqdm_data.set_postfix(postfix) + postfix['mode'] = 'Eval' if optimizer_disc is None else 'Train' + return postfix + + def _train(self, model, train_loader, val_loader=None, logger=None): + + device = model.device + optimizer_disc = optim.Adam(self.discriminator.parameters(), + lr=self.config.lr, + betas=(self.config.b1, self.config.b2)) + optimizer_gen = optim.Adam(self.generator.parameters(), + lr=self.config.lr, + betas=(self.config.b1, self.config.b2)) + scheduler_disc = optim.lr_scheduler.StepLR(optimizer_disc, + self.config.step_size, + self.config.gamma) + scheduler_gen = optim.lr_scheduler.StepLR(optimizer_gen, + self.config.step_size, + self.config.gamma) + sys.stdout.flush() + + for epoch in range(self.config.train_epochs): + scheduler_disc.step() + scheduler_gen.step() + + tqdm_data = tqdm(train_loader, + desc='Training (epoch #{})'.format(epoch)) + + postfix = self._train_epoch(model, tqdm_data, + optimizer_disc, optimizer_gen) + if logger is not None: + logger.append(postfix) + logger.save(self.config.log_file) + + if val_loader is not None: + tqdm_data = tqdm(val_loader, + desc='Validation (epoch #{})'.format(epoch)) + postfix = self._train_epoch(model, tqdm_data) + if logger is not None: + logger.append(postfix) + logger.save(self.config.log_file) + + # Putting the generator and discriminator weights onto + # the model instance for checkpoint purposes + # The parameters of the model are + # the weights used during the sampling stage later. + + gen_params = self.generator.named_parameters() + disc_params = self.discriminator.named_parameters() + model_params = dict(model.named_parameters()) + for name, param in gen_params: + iter_name = 'generator.{}'.format(name) + if iter_name in model_params: + model_params[iter_name].data.copy_(param.data) + for name, param in disc_params: + iter_name = 'discriminator.{}'.format(name) + if iter_name in model_params: + model_params[iter_name].data.copy_(param.data) + model.load_state_dict(model_params, strict=False) + sys.stdout.flush() + if (self.config.model_save is not None) and \ + (epoch % self.config.save_frequency == 0): + model = model.to('cpu') + torch.save( + model.state_dict(), + self.config.model_save[:-3] + '_{0:03d}.pt'.format(epoch) + ) + model = model.to(device) + + def get_vocabulary(self, data): + return CharVocab.from_data(data) + + def get_collate_fn(self, model): + device = self.get_collate_device(model) + + def collate(data): + tensors = torch.tensor([t for t in data], + dtype=torch.float64, device=device) + return tensors + + return collate + + def _get_dataset_info(self, data, name=None): + df = pd.DataFrame(data) + maxlen = df.iloc[:, 0].map(len).max() + ctr = Counter(''.join(df.unstack().values)) + charset = '' + for c in list(ctr): + charset += c + return {"maxlen": maxlen, "charset": charset, "name": name} + + def fit(self, + model, + train_data, + val_data=None): + + logger = Logger() if self.config.log_file is not None else None + + if self.config.heteroencoder_version == 'new': + # Train the heteroencoder first + print("Training heteroencoder.") + currentDirectory = os.getcwd() + path = '{}/moses/latentgan/heteroencoder_models/new_model' \ + .format(currentDirectory) + encoder_checkpoint_path = \ + '{}/moses/latentgan/heteroencoder_models/checkpoints/' \ + .format(currentDirectory) + # Convert all SMILES to binary RDKit mols to be + # compatible with the heteroencoder + heteroencoder_mols = [Chem.rdchem.Mol + .ToBinary(Chem.MolFromSmiles(smiles)) + for smiles in train_data] + # Dataset information + dataset_info = self._get_dataset_info( + train_data, name="heteroencoder_train_data") + # Initialize heteroencoder with default parameters + heteroencoder_model = ddc.DDC(x=np.array(heteroencoder_mols), + y=np.array(heteroencoder_mols), + dataset_info=dataset_info, + scaling=False, + noise_std=self.config. + heteroencoder_noise_std, + lstm_dim=self.config. + heteroencoder_layer_dim, + dec_layers=self.config. + heteroencoder_dec_layers, + td_dense_dim=0, + batch_size=self.config. + heteroencoder_batch_size, + codelayer_dim=self.latent_size) + # Train heteroencoder + heteroencoder_model.fit(epochs=self.config.heteroencoder_epochs, + lr=self.config.heteroencoder_lr, + model_name="new_model", + mini_epochs=self.config. + heteroencoder_mini_epochs, + patience=self.config. + heteroencoder_patience, + save_period=self.config. + heteroencoder_save_period, + checkpoint_dir=encoder_checkpoint_path, + gpus=1, + use_multiprocessing=False, + workers=1, + lr_decay=self.config. + heteroencoder_lr_decay, + sch_epoch_to_start=self.config. + heteroencoder_lr_decay_start) + + heteroencoder_model.save(path) + + heteroencoder = load_model( + model_version=self.config.heteroencoder_version) + print("Training GAN.") + mols_in = [Chem.rdchem.Mol.ToBinary( + Chem.MolFromSmiles(smiles)) for smiles in train_data] + latent_train = heteroencoder.transform( + heteroencoder.vectorize(mols_in)) + # Now encode the GAN training set to latent vectors + + latent_train = latent_train.reshape(latent_train.shape[0], + self.latent_size) + + if val_data is not None: + mols_val = [Chem.rdchem.Mol.ToBinary(Chem.MolFromSmiles(smiles)) + for smiles in val_data] + latent_val = heteroencoder.transform( + heteroencoder.vectorize(mols_val)) + latent_val = latent_val.reshape(latent_val.shape[0], + self.latent_size) + + train_loader = self.get_dataloader(model, + LatentMolsDataset(latent_train), + shuffle=True) + val_loader = None if val_data is None else self.get_dataloader( + model, LatentMolsDataset(latent_val), shuffle=False + ) + + self._train(model, train_loader, val_loader, logger) + return model diff --git a/moses/models_storage.py b/moses/models_storage.py index 0a651fa..7ff0384 100644 --- a/moses/models_storage.py +++ b/moses/models_storage.py @@ -2,6 +2,7 @@ from moses.organ import ORGAN, ORGANTrainer, organ_parser from moses.aae import AAE, AAETrainer, aae_parser from moses.char_rnn import CharRNN, CharRNNTrainer, char_rnn_parser +from moses.latentgan import LatentGAN, LatentGANTrainer, latentGAN_parser class ModelsStorage(): @@ -12,6 +13,8 @@ def __init__(self): self.add_model('char_rnn', CharRNN, CharRNNTrainer, char_rnn_parser) self.add_model('vae', VAE, VAETrainer, vae_parser) self.add_model('organ', ORGAN, ORGANTrainer, organ_parser) + self.add_model('latentgan', LatentGAN, LatentGANTrainer, + latentGAN_parser) def add_model(self, name, class_, trainer_, parser_): self._models[name] = {'class': class_, diff --git a/scripts/table_config.csv b/scripts/table_config.csv index 02a9660..e11a32f 100644 --- a/scripts/table_config.csv +++ b/scripts/table_config.csv @@ -1,5 +1,6 @@ name,path Train,../data/samples/train_metrics.csv +LatentGAN,../data/samples/LatentGAN.csv CharRNN,../data/samples/char_rnn/metrics_char_rnn_1.csv CharRNN,../data/samples/char_rnn/metrics_char_rnn_2.csv CharRNN,../data/samples/char_rnn/metrics_char_rnn_3.csv @@ -9,4 +10,5 @@ AAE,../data/samples/aae/metrics_aae_3.csv VAE,../data/samples/vae/metrics_vae_1.csv VAE,../data/samples/vae/metrics_vae_2.csv VAE,../data/samples/vae/metrics_vae_3.csv -JTN-VAE,../data/samples/jtn_metrics.csv \ No newline at end of file +JTN-VAE,../data/samples/jtn_metrics.csv +