From 6f3240c7239f316cdd664c4fa0290188d11ee30c Mon Sep 17 00:00:00 2001 From: LYFxiaoan Date: Thu, 17 Jul 2025 17:54:53 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8C=89=E8=8D=AF=E5=93=81=E8=AE=AD=E7=BB=83-?= =?UTF-8?q?=E9=A2=84=E6=B5=8B=E8=B7=91=E9=80=9A?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- prediction_history.db | Bin 307200 -> 450560 bytes server/core/config.py | 57 ++++++++++++------ .../__pycache__/mlstm_model.cpython-313.pyc | Bin 13511 -> 13511 bytes server/predictors/model_predictor.py | 6 +- server/trainers/kan_trainer.py | 40 +++++++++++- 项目快速上手指南.md | 11 ++++ 6 files changed, 92 insertions(+), 22 deletions(-) diff --git a/prediction_history.db b/prediction_history.db index 16cbe924d9fa39ac21b463885e535af512a709a3..2295ab09db1e02fac8465e8550f806efa50654dd 100644 GIT binary patch delta 34294 zcmeHwd5|1enP+8XR(Gq_DyglbCAB45a_i_LOL-qHpEKBo#mJI;U@Y5Gow975mV5|< zN(~qW2gWu9P8$#i3z2nRG-uK>jzC8K-Stp+wzHc~pB9qA+yno~Z>BQrEKEJqc zNB(AHWDYh4mK-bO{~`aM@;}IbH~*ddQ~7V?zm|VI|KM;SlcOZd5L4}RWu z1AhL^JMr_@cVse8{A%qADW643TKr@DX0|Q0slgk&&!)apUfh z&GpRh`hVL0Wd6PdhoyiDZ4}o^nAJJWKUo2H@!Q1FUb9u z+~4Ld&z+V3`+@fiRB|6yUQu4me>3-F{^9)H`6KyjlpiZ!&)uo~h4SDNf3jz#vVMuO zW~rpRnxW~wcC38xcsk>b#*u?7rQ7MZe%o2fnjtE!xvt?H*njlcfm_?Nj;&VKY~oUu z>FI|KAKQ0e-@B@n-8UBx9t~U93`g^jaZ5Vm$g5XA93dOlJ`~mtu|L>&&;FHfI}qVS zy)c^l!;+fonFikT=BNxI$z^V8&)U}*IkZx`k$(5=jlfZ>wf3xPdsd}AtK6PdqE40- z+W~tB@b=y9fE(Jgu5Zt}u03nl`v&)ZaB#1LzbyXbnZdov@4S4j@7c^iu<%R$-(2wh zzGt5p)V|l}PVc{AN#?3GZyz$(Nf!+cE?+)9+zm+u7iFc*!9BMr>&|w24b8As%~f^x zO5L8a-6>rwXolsRTJVT0P3-Q1L<=_^J#=t-|Mh8HUX^NNM0+M&tSQ?nfU-@^BV5M! zbkm?fX5V!J88sT1ZZ?JpnPnUG*0`jH2xw^qPxo!z7HDT zvVy5OzU447W8E~!$jk&^R`$sQi?F+8GsiR1jhW+(=_Lo8>HT}!kSzM1Rn-jDEo)`$ z)26SMT(hF;rLt+)Y}cvS!G&3AtkarurZkP_(bO#8F&M|PH+y?!#_`nh-AieKpK}4` z@$Th=r`&tbnI7)ue4#PBs?+iYDw^T@njaJ=m1Ti;tuj>53bt(%MD4v*5d@z|6aZ-_{_y}KDg_P(!=#m zBh}1HbTG$qo$7YV>u+VwTnH#S61v zll$_2t$a?pc=2NcJA2NR&R+P;z?}n`+$9UI%RZMsAN=6&8t0F!T`L{zS=xW^f=i{R z@;_Vn=;9AZFJzzVH5UAOv9zeJ9RHuY@wk+RO zJP29DftRCEgzjEd!hx8o+QiVmy{RWG|W z)hd-Sl)C1r)l$hcymA>lxwD@NT82||P19VX>cqcWjs+%PFV|F~XqPa~MYB}aZf^9Q zbEhJ|oY@(Cf3Gr{+lI%n##heSrYPGpW5EaaE2E8#qvJ|$4N@;&gVY_PUr=&~P-xwT zZA$KGWM6a{vV#rdN>3H3joXlVVZ*zXo*AV2UzSH3_iP+jdX3D^#y2)@pog*Gndgw* zw+Y!lLF(&IBX#YjcPo86(C$BOx?NeYHgjk2%mv6kzWEEvg2UAAKW9f9myC@o3w{k` z_dK2*&5hx4tZ`~=o6`Rz9)2_Sx5~mtGoNpKb9@^;>}=#GPAdc7r~Y1%MjJnw_!ni- zTTt*vlea62K7s7#uR!);bzI5sMYT_>+vs6u<0bWWCI2(@vL*}oy@e~3#TKeP^U7$W zf6F+DFLyS)EgR@zXQQ~~GxdHvo!7H9v-p|CkM?a^eAnXZdcWM;x7b|VpZ{skbNMIp z_x5b<{rh|^e_npsqF*oicZ(icba>IFi`Mq)1OH>-djk&-ylY^`z~+UoEd0sBM|xjg z`2K~v78(l|^grJ})3?3b9>*{zjVQ`dhS{9bkEm%uU+uqf8UJO)A#$n zzw47f*LNHY>|f{VH1B@V(o24=rg|<`I##g;>LD7gR!fej)gZI^w^KpTOhv0&n(Ej# z=)x<53-N6HR2`@67fnNVjN7=N32_OhxLPtyYzR8G2CQ>S)r%#^(rv?Vs`jl^u#U#< znpdi*H4q(2Vm*V=0V?I1Q?8h1(JJZ37)jHp6st8`bu2JHmW#fjK{eY0lD@9`uHhqDr6vJ1-GfxHTG7jpwt<>kEUA_W0k~L2 z2Ywmj;Z;DWCF3T6q^)};qgF>lMU0SEDyv1tSFk z*r@%A`zIJ7O|RfhYDH{fzFsk^X3;kEGL@_vCEL|??5!otGImEywV`7u&60*m)eY>L z#YV7BZ5$gLYkaIQ2144op&J3{hQ?1esl0(m+ZqpVRvRZaj|U51md}ySZM=Mr+Bmvl ztg&MKSg>dmRJ)dN3Ep%EcAG2su}OuwV8N&~D&-q%$5iYUJY3T-UgMiwb%jZ@e_NnfDsWZ9KPLZT!=^vBuq7#u{Hmv!HFMxACG{ zXnaQ9{PA(%G0=E^quRJ&{aE9w&0~!#N5{5x!)5Owthx~yb~V;)P#gEJ-{4{I$TV4F z-746+<(ZC2TQ<}?;6oz5ISu^iFP2MI@YBcS)sxL?kTM<6ljaZ$6|zWSwGbGrv*cx2 z%?f_|NUP=?I{fT?^6G-#H^j6AHwPeaj zz|2I-N~G*W%1NZ$SgNjSUM#6qbHEj6UgWZWb_0wbv=%aswa@q z6UZ0|WQ+tdMgkclfsBzr#z-Jz1pjrXJXVhnx|JAoq+(4Y6>A!)Skp+wnnuc5e|>EJ zcbQH>1Iw(_ukF#VYtygq7wY(jl}#z|EcDNVb_BottZQ7(bGfw1r3~>;6n43oZ!km` z|Fn2(F1O}4TVo6ypIhT+zs^m&{44tOxVp13{XM1XYMF?pV$a*J?g zj{IKUtf*WT100u=qP&=(#CbEnxkmoJl;h`M@$?Dl14NC9b@iIlUy}Sl zjfo}y+S6Z>{3@>p{RRK9Q=X_3ZIo{;9^A{?k6pT1@Svvvd05z{4ZmPQ@-%rZn|8sp zN%Hgy7D=Oa!8eIec)6n>OQmi>2hb^?4b(Ewh6+j`kl_|wY{a@-@X?B)7hG!PYEc2b zQq$;ylra{!3DTnKrp_`h{*l1z6l@ew4XxmyQLH0_lAsBy?iXxiqG^rA;n8kbikZU8 z@~zUxl%Ae-nf#JP>%gflUj$WaDXkUjM(%u7)txhCLXT({EDeH`FVte)G|0mG>066J zEw)3b#Xg;TK%O4%=H&>Vw*>F}ytHVeIM+;UstE-r^mWFRc7w-LsZ>d)Hm5;%eK}NF6NwFzGt0~+rI1%noHh?Gl1^XTlU%G^6+t>QfT+TEB(U%@RoY2}=T zBY1p6cGb)a-BV+QRot(ShTUvVoRlSz{D;cJpOEU-DTZS4)0bZFldgEpaB?A zTpubnSa&>MbAlbamF1E{(!nZ-gh^mmH6b0$r)Kb8OIp2UPCXj7kPoo!#5fY9=Fg5| zA&A`=t;D!)q+(n*QZW|9s;hAx1j>yBGI1FIumo=$mjM7v@Wyc&0I&oyaTdf<>v0(X zn#Gt5fK&pRxC}r#$d(|3RDw~D%K(7I1QMhY$QWx9Lu15;27+7gTSr=1cM?;CRBVcn ziXnnj3=yPaqk>dQV$cOIZ!;GT4Fxegf-%=MX52l7)IEj>Cf{KeM#vXUDT7&goB4gw z6eJF6iP?QeNCTn;F6oG87BC#{peY6jD3>Se7Ni*QE2IaTc{NcXT zJK#Dq*bYA~<@i}h6ylL33h~Gi#p_P7NrDCaCWz#c~x^vTr5A_@U@NJvNvGUrxQ zj}LhRm>Ya%j-Js9)HDK-{85%ZdY{rS_iKHq; z;$VlN28Xjp8_#c9Q`or-%hr5z*PZJt#bZ^NzUfq=^1;Y8$|@DRiVGu>uW6p{!s>;W zj=0k{STOyFyi|tISMYcxJ2a2Fk{3Mw)$DwFMa>`mCcA298&1F%tjJZL9jdm5@$TH< zsRO~P>EUh^*wHYq4LqdDer%+&18N zrqfDk4OS8qT(k_12%Di8-O{XNK%g*T({oKM^>nouFgWs07F zN^E%Gb8-BE>!*jiS-Gv@fzrNxi5XL%Bir$9%}y>c6t=|@qwSga!uYupy?+`Q++Sv< z@s;Z#Va$hTCA0Xg3=O?d|W7Rz=;VcU0arv44>qG}+w0 zJnYJFFfvF8YauFXMzG-IBl0p-WHG4I#y4x1>;7uao9&+8Wmm*`Y#bHf$U+`lq(UBB zq(UBBq(UBBq(Ww2q(Ww2q(as~q(aVtQZX@zfe}+BQVC=tF-X7?$V6h0fF+QL#2^7n zAkz|q7#J~CA{BCYB9%ZU5`zRRflMR@30MM|NDLCNdIFnB3}Rqpze37w>CchExg0|c zsTgYHr`?p$M9COxfW;;aDW{&ar3j0MBP0@*W1va|8z&?cheZ*OB^93;C6A>UpL1YC z#;JGIu8U5M2rMErt!x!Hw z<+yTl<6^kjxcKqHUGiP^*PmQlG7#1l_Kw$^Tq79>-DB@~-N`jUt^F->4UW7iI|P#} zn?m5dN0N!*vE&0~0m(A3z!oeBDkcmakUp@si2+xI4aF}wU|S5Upks}xpaLC|4g?fp zW+3no_i6eCm*gl&NW^nswWEmT6bx!xC?_s+&>CvMvf^@&kb;E7#C9Dpxzr=9DclYW zJR}Th1s!Ec_@R_lu#kdTLl**1H^K=WiVsQ;oRoW%hcnW{(AGaBO?~Rxz08HO=PzsP zKk|<0;cnbfhzm`7{<>3|IR%?6@mrqJQq-#s)SmFX^ECMXI`BS=GkC)vWnzvM zRycJCADY_dB0@q@@q*SuPx>;X_0+cEqMMhp+{M>J)1FS9knv z8*?)7wT<{|kR?!*fq&fN7ah^-#CgpyE4KMk7J2+m&0z0^{ejaXq zyA~b*#XG!Y_`*N$%JA=1G^-q0-5vf~p3wTWi?CVZ&` zjO+``f;YowX0i$Ld|ihtna|3&kc&*iga~^&RxN6A_TK)0YS_mZ5$GQrUPjnOqE*xe zV+DLuJZS^c-j;kRpfIOW@mHR$Ed@*7jGfrysKAGF%t<`ViY*z+qT~9K4&n*JAi51l zZsVK-wdQ4wAHMiWVd>#4USs^kYmA>xwvUM8aBgSWxdLA2SRFt?ceQBelqiW-T3p z{5ymJc_sNmG>z_T3EB&7lgSRs*Ed@RTM=-@2F{(0bgEA zBc*HGJGP*-H4OzvbJ9@sk~#uGAygw;Aim3l7{#DxKCS+|-abWPBZ&bCGZXI`W_Dk) z^-jd(VHg(EIKZ+$3x25juuv!OKEL&73i; z)O-M5&<*uP2Ov8%k%i9kbPr}b!wj~+Lm6s%rEZP`aN?oUXjyfvL=LHQu0-yrlil`2 zSO{bD8_uXLqcQwlLoo{aJETH=5T_!MPNYILhe(BL4v`Ag93mB}IYcT{bBNUI=Zp%O z9m`*+QFI}e85(OrjJi;t7(!=a2&D;Pd8mRR{8C{u;FK^P8L#Q^Y6bEj{nJD%kVhs z*MUAeI;&cclOQ3H(U`Onu)LDkLJJs-5EdmlgCe3F$0vMN$&QSFWXq*P2T}r5CFGF` zk(N#w4O$EWnNS-*f5@(+C{OYetk7hXHm+us40*DW+5^5cY&M2v89ol zHJg?L>$2vPBM+$xB6cMWBOZ7KGy0pcLVRptbU54UHi82okR>op!OqOpKU_}_IwmANplClnSj~ZK8xvMuCEu4T zw!LN?e}!IvLMbMcjU>LsIe5Z@vSNo6@E{k_2?e@iBhW9Chl9iWlm!dJi9!{YrGNNy zxi3;E&@_;l^P`FmH~nDbJo$ggz(KI?)Jeb!1+?u3Nu||!9XNKgV9{;{6FgQ3tP7;0 z(hI)e)M@9%?uvw25E~oTzliib6L~SyE~xcHMmhm%JqD$4eLrg`gapplK`e%bqKgRQ z$TcB-g@!w2;0g`$!-v~}7LnzuNuS8dKoqy9YamUUiEv2i-4ejy$Q4R~U)e`&8AN#^ zXfNTn%5`Eu_IM3`uuoZCXlriPzSskJXg(cc6%4Cs^k?mlCGVB`dLz z0&2AA0wTne;SQ^a8ixu7Oj$N=b^ryLwx=WJIYy6dK;G~WQ87Fzs+qoPyU?W(n-x)U zvsMd6N0G4B%%?kwseopl(QLIjTYq$7uG%PSxteR^Rv{?kF>`4&0Sx};PQ{)_7fgfA zdf1fUbPqFFB2-tP9fQqAfo&s>>FDt8b1Rm3|5t5|u^a7 zi5bSZ=)lZs;I^gMB5wo6f}#LJW4L31+$q6Q#wkAyW5TCHbrL7KjU(8{T~8Xga|#2D zNGc6$NWWMXF#Up)0r!P#PoO|Yf5gv3wc212$6KJlN4%ZqgSZg71+{FfUK>!n8^=Y# zw;60QSbtn|M;8w!u}s1U%aGs2F78{t2b&0@XMFJHvHsd%vk`OK-~;>gz+%wTqlP}) zhJ;GB4@HN!j3L=0-k?RxC&$ngy-}R=1#e9`F~%32J7jg8$WdIW+hTqDLq5$Zki0~J zYdsO?3;e9e!S7EMpyo0Aizf?=Ff=&`LwfmI`GRc0B6c{Xu<|u>uRvWL4V-lJ{OwCaDvh>nyvx_Vq6sG!NVLekS+ot@FEJWPSny~yyZQ{i5~(q#`A82a#vnanTq-#lD3yLyluEsZ zXg-FaLg>e+cASLD6x9Mn+OFedOFa?Us<_aMrPb3*l?5EkC}ReLljlm?$@lgWu)q*Y zcLpO<%A=4?JwA3JTt8mfaDf8ar)5CPl&uq8+u=?$7iK|$NRO#x| z-HF^;h?YFFz~*A739mL>NzpFESNVt~pqqjc>ze}C62R7}y9F9Q7iW?I&RGXh`4q$- z(QP|0;+;@95+S=Pu~f*mD68=lLZ(b z1S!JR2MjB@FA^gXa&~DL%Z>l>a%J@d>F~*E3^$rN5MdLlm=ugvxH1f8!P5^T#F4zX zToQh8a~}@dge=zv!|8wktGjl{H%RB=a{4gZ1>Oco#jK}D#k?EZy+DH9_seCzvL}4C zT@D-lg(H9iT|TTAtcXkmnsbiJC52ak%YFwUBtk^5i&~mHaHaOSS*_^>9Cx)RQL zo&hVQZ^Af)8OJ8ilYdD#`%T?hdUi4Msi~uBZJU$WK?pC-w(B_vpVTo(t+EITGm7W) zDYz6@yo##?!U^#T!EcXcmoG&=2l9mZKt1AiFU|@p%pbD2&Vvp@g?bZ481CdWd382+ z$lE127tDXTq=T$+deiAivC9;40JAoYGesV-VH&4Av{H!X5nfjuc~D1b;v1P;Vb3f& zs6-p&K92)eL}F{8KhYKqr7a?=VN54Mf{tiXdV%aAcZ)rMwP6k)-4v8f0Sz!oam%pC zH7n_pYha}(77lX{BRpEBfCOStC}!K|BYw1-`)wGyyEhmU4Z{_ok?<`z;F_-9;JcB5+AAEb8YTO^ZQw1a5S3 zf9fAa48j1@9c8c|z`O4aF=&c7{zx!*LkwyPeE*M&K^L~OS=dlX^ufMK>=iD@fa;q5 ziB!`*Pnt-mD0Q(-n%xj(OcL6Zkv4(G*MP>aKE@lSu)|+GczcL?h(FNegOQ=qSiPu5Eb6 z2b%^jZ`~CTeBvY0sz7^G`j?h*z5}tsJSNDKxh?Q5M3-Q|iGrceAes&DwE?Hac7r!} zL0pkc6tbAOmlN7;h(Y9CMydq9X@sJ?;Cjd)*M%26duhd1MH`H+u2{iWP)w$q1m@9d zO2b5V9Nv`%ovE59gd^NO(VwjoD+%+OD+a(~DiI>p@*(3&t(d zxc^ak72QAtuTq^ZZ$dy5>`|?7yZMm##AEW%JV=}Z?&j0o&(Ym_S|ssJiqQcqcPkX`xex0`rg^>AB7-S-OJiMjWBA&8{;swTRg?83Dbr7-^P!N z?>d2_i0yB%7$tzgj;bWAoY<}zq+=4m*SLt;Yp)jpgW~72BJM7%H<$j{QDM&giLkq{ zKj>KS9SInC`GZE`anWov0#XbZ2Yjsz>qW0H>-(+bkUMz(@wk}ypb~tW;AGihJY&FF zjV7rXpy}x}6wSQEu%4k@^u~PmTZgG^)?FAwHO4oL6F9ZYhFP&)ma<4&6#Fn##9{yz z+N4{)3S0p@Y>dGmf8C~rh@W_y<)@?zQm4Q=Y`$2Pu-9N;gBk`y2NeyrJpAGr8yt2) zY$LqCy{!wfir5V}t;JV6>CW)UK92%XVO1AiRm~tfy03zO*r{9(mPHDYi1-IspOQd! zc@bZJ2rEXdcGV?C+ZEESp0P`VnU6O<{t(kh|Nk9t$DCy%>Q02)c~+f{e4+o=H9r4^ zVq9hW6uSLkQ4Uz1~ZTj;l zFTN$BoVehc?sHv+OL;h6Gdhq6OgB;xByoisHO7$TfF!PP<8H{nEIc-IB2A2G^O3g9 z<|DDDZv^Yb%uTOVFm1XFBxCM?t8vS{Y|iKj=S;70k&4=`O2Z%-6!T-%fyaOX4!=s`|kof(*VI{EOA`W18x z&XANW&FW~IK3qW+aDq0A&_9yj5lARpheo^kk{JmFE1N>W&f8-Q?V`a?=YoCEly~tt zG-7Dq+9^~IJS_xQusyoj{B(mLMVz=X!f!ipLl~Pu>C+7brkeO6j=^ZF7VQT@KSmLI zbMCGgv$rPIAq0S%gaGi#jPywe0QXDvPnGVV@e;1hB35q=LJhk32qIYWKZJUbTLm1Bd|2Bd9kD4-ZVsC*F50DoZ<0uVT#gleu07P(mA zHeL9;Qcs?V&^*W}aiwP)j6Wj_D-%rNh~)zhj!((KP8~xr4dZ1%sz-=|INtDmIp2Uo zm+2FMM5`bTYx!otsS@}7Q}7i0IPo#{#CW7(%=U`!JwOo#Z7yBR;kYhLzHu2U4Lh@< z@nKvG*ghw1LrT`iwvkMa21Y9Vsu+g!YhcRbDCQXd*pUdIW*WStB8FlD5`I;{niVnQ zIfSIaBrN=@KCBGNK!L2q4d_&GQB33+nY9bTF%s~mVSm`=nsB#o8l>Lr8&e|BtibX&Bo% z`8p|he!@`}N=nz($nTi610j2mh9&fk_-rX0^1)RxLB^LmVwsb$xV}-aR!zko>kpfH zLS_KPZ3#k)=yHzmj1ZSF2nAtN5y;TyN5#0ErtXndhC+ua^qHcTaEWk)zERwnDKw3u z1963eXihvguXqsJNHY7tdji)?!`B)t8D1BM69^swQ~2&SG{8*pK3pt4g2RM>-HN4M z@SFH>UpQk^knl+();3k?v1Y|&VjO3Rf0q3VaY^qcvRkKCJDK#;%RLym3<{JbXUnU} zj}V$}UhjP0@&1S8Jidy|F>x&a{;6aJgIlq4YVVO=gnI-zOfV-$?UJX&`B~cth@+$X zcuz}p`dSG-rbeq7)59M$0UIDxec0A;xWoGhTMYbVNSi@H{2cs`xgk~~(m}FT5XTb< zxzI*aO)LrFC(-n663SxoGR7N{c^}m;6Z{1EN}vMTBt8UcNa2kr+#taJb30&KSm`fr zd;4&DF&002vq3=nj;5}*Wp1GD^- z;tI3g9wGq@3IG5AS^<{;3xx+;lMWyzlQtkYlW-s)lae4P3JT#0xDJpBZ~>Q-u@Kvr zdDsLkv&10}4wuu<1TC^K{|^CLvOokN3R(vd00*Q8kq3U0DqtWK2T}(=4%G~O3}gpw*rF^XNM7I0f!N10*4W21Gf=p1oJos0SuP_441PJK#dHy Im#hS71AdM_M*si- diff --git a/server/core/config.py b/server/core/config.py index fddd7b4..76a314a 100644 --- a/server/core/config.py +++ b/server/core/config.py @@ -133,7 +133,14 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str: # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名 # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth - filename = f"{model_type}_{product_id}_epoch_{version}.pth" + # 针对 KAN 和 optimized_kan,使用 model_manager 的命名约定 + if model_type in ['kan', 'optimized_kan']: + # 格式: {model_type}_product_{product_id}_{version}.pth + # 注意:KAN trainer 保存时,product_id 就是 model_identifier + filename = f"{model_type}_product_{product_id}_{version}.pth" + else: + # 其他模型使用 _epoch_ 约定 + filename = f"{model_type}_{product_id}_epoch_{version}.pth" # 修正:直接在根模型目录查找,不再使用checkpoints子目录 return os.path.join(DEFAULT_MODEL_DIR, filename) @@ -151,32 +158,46 @@ def get_model_versions(product_id: str, model_type: str) -> list: # 直接使用传入的product_id构建搜索模式 # 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth" # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式 - pattern = f"{model_type}_{product_id}_epoch_*.pth" - # 修正:直接在根模型目录查找,不再使用checkpoints子目录 - search_path = os.path.join(DEFAULT_MODEL_DIR, pattern) - existing_files = glob.glob(search_path) + # 扩展搜索模式以兼容多种命名约定 + patterns = [ + f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth) + f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth) + ] + existing_files = [] + for pattern in patterns: + search_path = os.path.join(DEFAULT_MODEL_DIR, pattern) + existing_files.extend(glob.glob(search_path)) + # 旧格式(兼容性支持) pattern_old = f"{model_type}_model_product_{product_id}.pth" old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old) - has_old_format = os.path.exists(old_file_path) - old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old) - has_old_format = os.path.exists(old_file_path) - + if os.path.exists(old_file_path): + existing_files.append(old_file_path) + versions = set() # 使用集合避免重复 # 从找到的文件中提取版本信息 for file_path in existing_files: filename = os.path.basename(file_path) - # 匹配 _epoch_ 后面的内容作为版本 - version_match = re.search(r"_epoch_(.+)\.pth$", filename) - if version_match: - versions.add(version_match.group(1)) - - # 如果存在旧格式文件,将其视为v1 - if has_old_format: - versions.add("v1_legacy") # 添加一个特殊标识 - print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy") + + # 尝试匹配 _epoch_ 格式 + version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename) + if version_match_epoch: + versions.add(version_match_epoch.group(1)) + continue + + # 尝试匹配 _product_..._v 格式 (KAN) + version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename) + if version_match_kan: + versions.add(f"v{version_match_kan.group(1)}") + continue + + # 尝试匹配旧的 _model_product_ 格式 + if pattern_old in filename: + versions.add("v1_legacy") # 添加一个特殊标识 + print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy") + continue # 转换为列表并排序 sorted_versions = sorted(list(versions)) diff --git a/server/models/__pycache__/mlstm_model.cpython-313.pyc b/server/models/__pycache__/mlstm_model.cpython-313.pyc index 59692f5ef902abe922acd12217833fa133491442..61289434cfa7cc3153d1f32559bf8fc8fbaae27e 100644 GIT binary patch delta 20 acmX?}c|4Q*GcPX}0}#0Ftk}rC#{>XMPX?v{ delta 20 acmX?}c|4Q*GcPX}0}u#pF5Sqz#{>XKum))W diff --git a/server/predictors/model_predictor.py b/server/predictors/model_predictor.py index afbaa93..e56db5a 100644 --- a/server/predictors/model_predictor.py +++ b/server/predictors/model_predictor.py @@ -216,11 +216,11 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, model = MatrixLSTM( num_features=config['input_dim'], hidden_size=config['hidden_size'], - mlstm_layers=config['num_layers'], + mlstm_layers=config['mlstm_layers'], embed_dim=embed_dim, dense_dim=dense_dim, num_heads=num_heads, - dropout_rate=config['dropout'], + dropout_rate=config['dropout_rate'], num_blocks=num_blocks, output_sequence_length=config['output_dim'] ).to(DEVICE) @@ -241,7 +241,7 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], - kernel_size=3, + kernel_size=config['kernel_size'], dropout=config['dropout'] ).to(DEVICE) else: diff --git a/server/trainers/kan_trainer.py b/server/trainers/kan_trainer.py index d074880..63f44b7 100644 --- a/server/trainers/kan_trainer.py +++ b/server/trainers/kan_trainer.py @@ -168,6 +168,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None, train_losses = [] test_losses = [] start_time = time.time() + best_loss = float('inf') for epoch in range(epochs): model.train() @@ -225,6 +226,43 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None, test_loss = test_loss / len(test_loader) test_losses.append(test_loss) + + # 检查是否为最佳模型 + model_type_name = 'optimized_kan' if use_optimized else 'kan' + if test_loss < best_loss: + best_loss = test_loss + print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}") + + # 为保存最佳模型准备数据 + best_model_data = { + 'model_state_dict': model.state_dict(), + 'scaler_X': scaler_X, + 'scaler_y': scaler_y, + 'config': { + 'input_dim': input_dim, + 'output_dim': output_dim, + 'hidden_size': hidden_size, + 'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size], + 'sequence_length': sequence_length, + 'forecast_horizon': forecast_horizon, + 'model_type': model_type_name, + 'use_optimized': use_optimized + }, + 'epoch': epoch + 1 + } + + # 使用模型管理器保存 'best' 版本 + from utils.model_manager import model_manager + model_manager.save_model( + model_data=best_model_data, + product_id=model_identifier, + model_type=model_type_name, + version='best', + store_id=store_id, + training_mode=training_mode, + aggregation_method=aggregation_method, + product_name=product_name + ) if (epoch + 1) % 10 == 0: print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}") @@ -301,7 +339,7 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None, model_data=model_data, product_id=model_identifier, model_type=model_type_name, - version='v1', # KAN训练器默认使用v1 + version=f'final_epoch_{epochs}', store_id=store_id, training_mode=training_mode, aggregation_method=aggregation_method, diff --git a/项目快速上手指南.md b/项目快速上手指南.md index 1912719..8000320 100644 --- a/项目快速上手指南.md +++ b/项目快速上手指南.md @@ -93,6 +93,17 @@ * 在这个新函数里,确保实例化的是你的 `NewNet` 模型。 * **最关键的一步**: 在保存checkpoint时,确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。 + * **重要开发规范:参数命名规则** + 为了防止在模型加载时出现参数不匹配的错误(例如 `KeyError: 'num_layers'`),我们制定了以下命名规范: + > **规则:** 对于特定于某个算法的超参数,其在 `config` 字典中的键名(key)必须以该算法的名称作为前缀或唯一标识。 + + **示例:** + * 对于 `mLSTM` 模型的层数,键名应为 `mlstm_layers`。 + * 对于 `TCN` 模型的通道数,键名可以是 `tcn_channels`。 + * 对于 `Transformer` 模型的编码器层数,键名可以是 `num_encoder_layers` (因为这在Transformer语境下是明确的)。 + + 在 **加载模型时** ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1)),必须使用与保存时完全一致的键名来读取这些参数。遵循此规则可以从根本上杜绝因参数名不一致导致的模型加载失败问题。 + 2. **注册新模型**: * 打开 `server/core/config.py` 文件。 * 找到 `SUPPORTED_MODELS` 列表。