From 751de9b54825bdd3170383d083525ceac1f930c0 Mon Sep 17 00:00:00 2001 From: xz2000 Date: Tue, 22 Jul 2025 15:40:37 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8F=92=E4=BB=B6=E5=BC=8F=E6=B7=BB=E5=8A=A0?= =?UTF-8?q?=E6=A8=A1=E5=9E=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- prediction_history.db | Bin 114688 -> 167936 bytes requirements.txt | 3 +- server/api.py | 75 +++--- server/core/config.py | 4 +- server/core/predictor.py | 135 ++++------- server/models/cnn_bilstm_attention.py | 102 ++++++++ server/models/model_registry.py | 64 +++++ server/prediction_history.db | Bin 53248 -> 0 bytes server/predictors/model_predictor.py | 226 +++++++++--------- server/trainers/__init__.py | 52 +++- .../trainers/cnn_bilstm_attention_trainer.py | 118 +++++++++ server/trainers/kan_trainer.py | 7 +- server/trainers/mlstm_trainer.py | 6 +- server/trainers/tcn_trainer.py | 6 +- server/trainers/transformer_trainer.py | 6 +- server/trainers/xgboost_trainer.py | 142 +++++++++++ server/utils/model_manager.py | 63 +++-- xz修改记录日志和启动依赖.md | 2 +- xz新模型添加流程.md | 222 +++++++++++++++++ 19 files changed, 940 insertions(+), 293 deletions(-) create mode 100644 server/models/cnn_bilstm_attention.py create mode 100644 server/models/model_registry.py delete mode 100644 server/prediction_history.db create mode 100644 server/trainers/cnn_bilstm_attention_trainer.py create mode 100644 server/trainers/xgboost_trainer.py create mode 100644 xz新模型添加流程.md diff --git a/prediction_history.db b/prediction_history.db index e33ba52951f9ddf913e4b1cba64f078803fc5b94..93cb3a0a8a3e8996bcfc95e7d3daf246fa11858e 100644 GIT binary patch delta 17512 zcmeHPYj7Raeb?GDvL#(z3ELR!VN3D@3%r-}ekdfAAql1}Ely!lmr${L?=E#}WE?9F zk14Ve40NU>HW%1wCzC>$Hb}8D{)i0`lb2GW5xB` ziwPP#CaS&tI|qx87r$7{iVqbZEIv^DT=AaboyD2r;o>KY$^6`t>%R7{n?JMRDh%GE zaJhuxvOC0O*JZfuycCxm6S$1;!DZ|cTt;`{vVA8m+s1Jj8OCLJYxSw^zdozGSVvJ5 zUh4i}&xtNp{M){N>^ob0uGmxfm)^U(?p=S^hHdL#?%uTiKa2Nm_}#*h4c8R*b^T4x 
zC%SH0|3u;UdOLc$*YDZzr`=y)zrXtn#hZI4`fly{PWSh^)4r2Ey!WrVo-J(e`%vHJ z-Y0uM*85*wZ*RDL!$0?&#=95)&42#63m)3rIm^?8Xqr|+tEdh|QlVk2xC$$VbINUn zlZfj*Yi|rJNfOHv9de!UP;i=tmKr=)DI3j_veb!DZlFvRNT;EQD;guqaVRZgCS@`y6DizC6OouyI1PEqDxt6<31g0L zRy37yMM^Hq33mgNL}^oDJdAWjhC(VG+A@j4L?^oJ7GTPfxq*?6Q?!WLFtw~43KjD( zrc8yFSQ+y;PADl;Ke8fqN@EkoBDI(@rePFSNJv>4r^KeBq9fu4DoG7VWfW>+C;*b> z&~U0kA`(N5shCQvFLwj+`H7W^gt0BpIw*lvlqffO!reDVLR3nAN>!E#!MiJdN@GnJ zu_OvP(_Aa1(y*)u8jBet02Eas7DhJZ35irviTK{yM+&QqWg#UDkgHi? zv2I!kqcRtja#BgSl9#!Wh6$P`_ISAllno_7B%)Y1 zQvgV@C}cDiEG`Snl%DV(8AaSOEkk0^Ed+_fXtrY92gIbTB8A~SL_BQR8^p(*Pm z6crnX)_^()C4z#!@Pl6B-(9v*%*vRel48F#6Jf;5I+QFW3g1{!WxCr9RB^0HqCr8T zf*An_Hb+F^ERT|y8ZC63uw8B>l~xi-0AR|nB|=4EoW{W1CRD~zL{m+6x`AbieN7V+ za+6}pL`R`<6qI7wX)0ny;&g`_7^PCC1T&^-gc%7|44PA6Sra>JD#RKc$#Fk2PB3Re zLY)YOZNyebHsT>CB8`l+hUj$64UAL6602DVa>Tj?_9f<(G^8ny%QA|&P<+(SNHvdG zB@Qc82>=Avwct-on^fX5SF>q_$?a}rS>PocBu**@BN@k6R;;5D&P~Lr07GnZ10e{a zBqastSHZgveyb4UEf{(pZZyF$vb4Y84t~MMx8-WI0NWUp3rSXkHl550fUJKy`o-usK9xTAPa;V+A)`ee_` zU%P7b!k*c&&Jh_xB_a5z}sjy0CNWP;l#^ z!*O)zj$7lyhmX$4O`Rh{0Vo|T#dy6@v}v53;cZ%HSi45ru938Byu4jw$v3*Cx`oug zX*OJFbpX^`_qJ=itX<=!?HVWAHSSqxYhIxb%TF1n%=&=VOwA4){XT|L;uL~{$W4ckiQOoZg~34@O1D``}+z| zzH8Uq*wOV!=MUE0I6rsI8gpXv@Ic3nBiC)bd@Q)?zS{-{HqQ;^;hJpv5sX zX`-l5On~ZKSM}95w)EQ)Ws*qL)qwoz>Q^Gb7CZ6 z6rxFG`%eU$vQHcc2GJ#nvRqII2Pve@=#xQ+FOrZ4)>6xy{y_8LCxbmv9so+8JbHNg z#I7Na-i{AA^p1G+@<@{Gf4&l28?Y>VKG*`hIQ|=aiiDJ1#H-JgG!SG)SYrgq=B^1Y zs($H>!2;S12PQ%(OSAcX!Itc?UkwI(O$pN9q2(6x_#XxrXRo5UAI*{~2M7@o3cMOU zy(t?3C5JzadkA638%cnq-z^O3w8HU<}%nD@qhLQPu&A^Q_qecJem{ zFLdO#k=j zse%;%mF}siL0;&=#xwN7)(}$GA=v%`AV8NOT*BHAnk(!B9t!aiyulj0+@jBr>Wf>p zWF9hE0+|9}tuc)hq7E6FE#?tcYGb4pEQ>Y-!~K*<5W7VSH_%`@qjv`M29Eu<_zVV; zv_$V69lw3*Saj&q$@t_bDM8M|`T#XxUtob)Lls6GJA5dbIe6&c%$>DC!s`P;EH`cr zmJM_VgV~GM2HR#u3G*5%pvZxcHhB8Bq;fEt#*;0}b)5%mJxg(QFTE|g1JAUq4LwI^ zD*ihdC`W>XgtbB$K)835jvt}^O@Ig+0Sn%EhLw2i2>0_50rOb{Lq=hio9xuj2b*ST z2}&LKfvyP$x}0II64bedg=O8QQcOWgs`f(cJ?LF3yb+jzcnwIK!Q|9b8w@lh2mtK3 
z7O>B-VQ!eg2<#ElKpwklZpJ4aoKAuKbh#Q1n=_lv1RI0up^=Yd|8_PQseXIJJi0A- zwiAf5Y`zqHXbrQ~A8os0Er_s6zVXZzIm`sl5l$vw&~c!naJ>7?_1CZa^4g%|z^S>9 zm=n9kTSUuhW?aLHbBuC*#-9%c3SyG7Nuf)iY%a3>p9%K9Z&<{O;XKQ&i{Y_Wta2E# zOT=&uEmG2yK@lyAmG_Srj-niEt^u}lv8fQ-#phXV4XVxC zJygvGQgjZN#(6Z}K$=`;AWg-|(HSUKIG{RYY~$kX4HR&aNybW2KrS_@wfB$Nln0hN zrrTOD&Ex#7Vsv`+D7aS$*hP!P_`X4JHV*(Ey?cLa1n6z08@aQg4tVCx;}O?Fm)i=p9{xz4_vkTfHS(fDHyEEjjcnEXTPQg00?e?{PAr zWVI!PS6}?Ii@n9(4VOxkU?EwO9eE-c&c1s`u%%CcwV~2Ag>_SBWyp+kM*}=?PSvit z;I59s$1k{Res1~#`{1wb^w#K*L-Is8F!0d+xuHDK7f*p#MUMrUNYAKrEGH!vhevaxG}joHf=2P0!J zSr)BDz=YpRZNo)N@E0);x)+SU`cH4})$hH5SG#YRbnX#v!PT0*$%b2Rr1hh7$^Yuu z@Q1tVkW`~TKGwL;9y=4%r+xOhBsh+pehHz=ZNQ8U5wp>05FH!fFC# z_l}xA^KL~=pXCXTEqWC-V@blb;xJWf#?t#hkMTUKTp<3S>!psam*(frcFmrg+~r9B z8c+K7T{|%F=-qQeD}^YVBb1G3(`#p*dgHSXzxmvsy!PxPZ=E>(#&6%h=dJrsz5ZPF zwJ)EUXyObP8JSw(3|iNCnRq_vFM>hjBo!qD<&A9?+63-BEC&X^0nuaxL7dmW#fJ(S z?9jq57lM#9yjf?|Tf`oyBw6%mP;%|gWktx}5Lo!I8y0*v^f95WwOVS)3`f+}87tWj z9}NawpNrr!L>^$Vz+2PKa~m0emuGN30e`{@SS4s~u;B3Z9L^-wlqv(O$cQBziABK? 
zr0`WFPA9lvu6rqiIuWY4b6o_)A%-}P-I{Hk2)57SaX3^Qv2*9G!?0rB4{Nw@t*dn2 zW6gm>pS!Pf^zs!u?`wg>EkRTqQWnPs1YRE;tvbhtUhHC55Ij~2CKdcQ*Xi6ge42q+ zhtr0TJeLSqIFWGCFA@J@y!+be^x!lc)V-rn69~^^S2ZvK8WdlNJCFeO873NDt_U2r z5Dq?0Z?IrQkzmzvn15r#r4dI?<8CpBE6Jg$nLBTHp7Hn!VrIP44L|eh-r(CoK(Zg5 z1r2@cY|xwi_32<3VunfQveg6+{r#t89{OKkrZpX}gPA&BKQ;IFP$ZYMDv~@$2)Psq zWRp~rRF>fWTSPVB$ThfNb(W})`WaOrRcz^Oo=xyudUA4it3XV19PX?t5Xa9g5gX1F zM4<};B0j2`t^f0)@VqE|??oZcL-`A$aEV>9r?8A(Yd*3xm}?q1%i48 z(LCAY44!AE^UU<#GE+tm2T!vc^xvMDCR&*(PqM?i%uITcfpZXpWBNQZxwC@vM*2_8 zNauOzzdbWu+R99MQjFexX2Q8EGTWNZE(gwQrt_NVz1B>#&`gMvoGZnyJu_X_%1rL8 zw>h(~U1Hz6&P?!`aQurilk=MCyk`0dhmUg5{{l0u>3F-N>$@FqpPKu3gpWeUP1h|s zhfp~(Z}WwZ`rI7D~ydz183l!H9@LHQ)h`k@i5knFe@`l`5 zpHpmjH@6IFJ*KL_`@_q zuuzqdX(j;%AVG?StUd964YOwmvwy#Fq$frFNDtW`wGg{Yhl49TnWIc#K}N1yGu}~{ z?4Id5-TCI4@%g!a&Qo%cUHJ#W2S)4X1V?8Me=3=3Bse#n zPSs+uU?x5(kO8Sc9dwbIa3X}GYvi<0oIRs@0@-Jc)Rq;xYRrTb8N}@sj?Hl@q(EJs zEgUng^qAwOmHHezt<>kvtv0rPWj*I54^AHY6Ias~kQ?PHt5vJKW01LU0dEL|pB|9& z;8X{hw1S+Y(#L~*;9Qj4$Sk6=t8vTN4b+o1j^%I=2{~}n-Q$&L#-E@y2GMK`c@`_s zSEt6YT=nrMFIuHvX9LJ;=x;~z?s0RD$L;FW^|@Ti=+=TQ(+7Wz{>R ze&&vQT(Ylficqfrtjn=kq3i#bh{=Of9pq9A$jw!0D;1@^V~{B=ag+))XLPQB-U+1^`^##7olu$rY&9V1yVg0Rs0g~6XuFDiwIIq zgfoudk@zwcc6+HAj!~qe8tYSZszZy@MwjHeg7&( z049IsJ1f6A8@n&qdSQ#nTQ@=TpjzJ&-m)cmYOwv5d^E3h#8~T=5SO8Zf{C9Vt&HN? 
za~H2KD2$-%o{p}2=I3URBsi;mP{XSVSKgLQpgJ%|x&Ep_jmBYqJCWFU&})*7%d1(N zBnu~##3I}w1!>wS)VH8`95sfh3PK3L=*4IUx&*UKpga&|Lr4bn(!>Wl@DK`)H0r5b z?SN!-LC~NEjHAZRWeplsydfLW)gZW5yoMse0VCR#bop#vUjzo?(F#?GC}XVEIQlpT zUhc#j5)hHhh+JvqG6@Gag^+!yHE~sb_(KMgAmdS2U-|JSm(DwUBkG!v|AsmTyb0OP zE){P{sDq-)4Sy>DBGSU8*Dhf<5$kqL>vG+<<-7EIv7P(Ok_L6#yiJP^rnRXaXX#8Cn!k>H595E`*Z1b@E) zO&|vMPB&0sL*a&7UkfGP6xiMt{KZg+iGN2%X%w!$`m5q)2SU10$x`T7_ zGWj?&Ngf6UeiPoq{5SX>@$KYa&(F>KhNqgpjkl23gLf7W1Ai=!CSN+=1fH|J@_crD zqC87@f_T1d7Bo1= 2: + metrics = result[1] # 通常第二个返回值是metrics else: - log_message(f"不支持的模型类型: {model_type}", 'error') - return None + log_message(f"⚠️ 训练器返回格式未知,无法直接提取metrics: {type(result)}", 'warning') + metrics = None + # 检查和打印返回的metrics log_message(f"📊 训练完成,检查返回的metrics: {metrics}") diff --git a/server/models/cnn_bilstm_attention.py b/server/models/cnn_bilstm_attention.py new file mode 100644 index 0000000..65d5ce9 --- /dev/null +++ b/server/models/cnn_bilstm_attention.py @@ -0,0 +1,102 @@ +# -*- coding: utf-8 -*- +""" +CNN-BiLSTM-Attention 模型定义,适配药店销售预测系统。 +原始代码来源: python机器学习回归全家桶 +""" + +import torch +import torch.nn as nn + +# 注意:由于原始代码使用了 TensorFlow/Keras 的层,我们将在这里创建一个 PyTorch 的等效实现。 +# 这是一个更健壮、更符合现有系统架构的做法。 + +class Attention(nn.Module): + """ + PyTorch 实现的注意力机制。 + """ + def __init__(self, feature_dim, step_dim, bias=True, **kwargs): + super(Attention, self).__init__(**kwargs) + + self.supports_masking = True + self.bias = bias + self.feature_dim = feature_dim + self.step_dim = step_dim + self.features_dim = 0 + + weight = torch.zeros(feature_dim, 1) + nn.init.xavier_uniform_(weight) + self.weight = nn.Parameter(weight) + + if bias: + self.b = nn.Parameter(torch.zeros(step_dim)) + + def forward(self, x, mask=None): + feature_dim = self.feature_dim + step_dim = self.step_dim + + eij = torch.mm( + x.contiguous().view(-1, feature_dim), + self.weight + ).view(-1, step_dim) + + if self.bias: + eij = eij + self.b + + eij = torch.tanh(eij) + a = torch.exp(eij) + + if mask is not None: + a = a * 
mask + + a = a / (torch.sum(a, 1, keepdim=True) + 1e-10) + + weighted_input = x * torch.unsqueeze(a, -1) + return torch.sum(weighted_input, 1) + + +class CnnBiLstmAttention(nn.Module): + """ + CNN-BiLSTM-Attention 模型的 PyTorch 实现。 + """ + def __init__(self, input_dim, output_dim, sequence_length, cnn_filters=64, cnn_kernel_size=1, lstm_units=128): + super(CnnBiLstmAttention, self).__init__() + self.sequence_length = sequence_length + self.cnn_filters = cnn_filters + self.lstm_units = lstm_units + + # CNN 层 + self.conv1d = nn.Conv1d(in_channels=input_dim, out_channels=cnn_filters, kernel_size=cnn_kernel_size) + self.relu = nn.ReLU() + self.maxpool = nn.MaxPool1d(kernel_size=1) + + # BiLSTM 层 + self.bilstm = nn.LSTM(input_size=cnn_filters, hidden_size=lstm_units, num_layers=1, batch_first=True, bidirectional=True) + + # Attention 层 + self.attention = Attention(feature_dim=lstm_units * 2, step_dim=sequence_length) + + # 全连接输出层 + self.dense = nn.Linear(lstm_units * 2, output_dim) + + def forward(self, x): + # 输入 x 的形状: (batch_size, sequence_length, input_dim) + + # CNN 处理 + x = x.permute(0, 2, 1) # 转换为 (batch_size, input_dim, sequence_length) 以适应 Conv1d + x = self.conv1d(x) + x = self.relu(x) + x = x.permute(0, 2, 1) # 转换回 (batch_size, sequence_length, cnn_filters) + + # BiLSTM 处理 + lstm_out, _ = self.bilstm(x) # lstm_out 形状: (batch_size, sequence_length, lstm_units * 2) + + # Attention 处理 + # 注意:这里的 Attention 实现可能需要根据具体任务微调 + # 一个简化的方法是直接使用 LSTM 的最终隐藏状态或输出 + # 这里我们先用一个简化的逻辑:直接展平 LSTM 输出 + attention_out = self.attention(lstm_out) + + # 全连接层输出 + output = self.dense(attention_out) + + return output diff --git a/server/models/model_registry.py b/server/models/model_registry.py new file mode 100644 index 0000000..661405a --- /dev/null +++ b/server/models/model_registry.py @@ -0,0 +1,64 @@ +""" +模型注册表 +用于解耦模型的调用和实现,支持插件式扩展新模型。 +""" + +# 训练器注册表 +TRAINER_REGISTRY = {} + +def register_trainer(name, func): + """ + 注册一个模型训练器。 + + 参数: + name (str): 模型类型名称 (e.g., 'xgboost') + func 
(function): 对应的训练函数 + """ + if name in TRAINER_REGISTRY: + print(f"警告: 模型训练器 '{name}' 已被覆盖注册。") + TRAINER_REGISTRY[name] = func + print(f"✅ 已注册训练器: {name}") + +def get_trainer(name): + """ + 根据模型类型名称获取一个已注册的训练器。 + """ + if name not in TRAINER_REGISTRY: + # 在打印可用训练器之前,确保它们已经被加载 + from trainers import discover_trainers + discover_trainers() + if name not in TRAINER_REGISTRY: + raise ValueError(f"未注册的模型训练器: '{name}'. 可用: {list(TRAINER_REGISTRY.keys())}") + return TRAINER_REGISTRY[name] + +# --- 预测器注册表 --- + +# 预测器函数需要一个统一的接口,例如: +# def predictor_function(model, checkpoint, **kwargs): -> predictions + +PREDICTOR_REGISTRY = {} + +def register_predictor(name, func): + """ + 注册一个模型预测器。 + """ + if name in PREDICTOR_REGISTRY: + print(f"警告: 模型预测器 '{name}' 已被覆盖注册。") + PREDICTOR_REGISTRY[name] = func + +def get_predictor(name): + """ + 根据模型类型名称获取一个已注册的预测器。 + 如果找不到特定预测器,可以返回一个默认的。 + """ + return PREDICTOR_REGISTRY.get(name, PREDICTOR_REGISTRY.get('default')) + +# 默认的PyTorch预测逻辑可以被注册为 'default' +def register_default_predictors(): + from predictors.model_predictor import default_pytorch_predictor + register_predictor('default', default_pytorch_predictor) + # 如果其他PyTorch模型有特殊预测逻辑,也可以在这里注册 + # register_predictor('kan', kan_predictor_func) + +# 注意:这个函数的调用时机很重要,需要在应用启动时执行一次。 +# 我们可以暂时在 model_predictor.py 导入注册表后调用它。 \ No newline at end of file diff --git a/server/prediction_history.db b/server/prediction_history.db deleted file mode 100644 index 0a671d01ba7d8597dbde93982e9dc84efb597bbf..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 53248 zcmeI5Z)_9i9mjqCn>e-~AQVRHHp*M zsQ~Q_q$DeAE0vC=ZPvopZB@6T(v5X#tAaLZQrBs(_OH@3X)lN!v`Ld%ZR?A@*z-Jh zk9~*jm;hQC`W?dg{vQACd7j_rdG37A^Z0!G9va9NGU8ZndLmU270Wt{AXs`u(PFW9 z_5b7#+uY>Uu50>7Fz9;HYuP>*^7=v+k5jTZ-#|b0w0ZyIUTpb$Q^x%$dK4kwi{uU1 zfB+Bx0zlxlCveQ;3dBOf@j!MmojH)DYC4ym87<_?!j$Um8;-{kadG!x zfBX@#{$AqHplHsoL(e#wO^0T?1XloKVJ=Wejf`ierZeg6Xd#=MOdihCy^dC=SoMMv 
zvHJ(&V)grChg;OQY+6jjA4!ONhj;IZ4IdC6j33xmezQ^w&1rBbAr9^z7^wD`Oig6! z`Ap=}nek-d=u`s_UX|)n@`coNA(>8*k8D=iYLBs*!VGDTbn0lH;7-IJh!2a^UZc|) zl1Dm8KJ)A4lN~3mnkpPNJ1`>K^hP4}d^k4T_odiyhb)D7DkH$tu_T2;vro8Oflx@8 z-^v=!+`?w7$<$!x_Y9528%=Y#rXJ;1<6ef=Ilo4E7)mqdq9^B# z$Eb;E=RT(bvlu0Sv-%(qsK+dONfyxl}!ORK@2_gG_4 ztCQFESPS>@93f9W1q@}>0-w(wtIM68lKH8jX0zp{GLzZKL#v$^$ho;NlQ+~NR6fwR zYA1}X)h)iZ_N?B2W{RGvv=cc$iF8cb!t{S`s*s%^ z6C&y4W2s4+kYsbyvLr{sQdbzOu8K^^%1#CEjNt8EJyBVfDx=B#L4D%=peetDEW=6$ z$!Jd4S9+pSk9_k1$WfA+7J%g!l>pM|NV+E#84E|cMx?OXlaj-!jHHA| zy0Maqrn}Q!nUNa}07rMKcsq`EN3mf%f1@pc32PPrUDT)Je>6w_eRSLPaHtRv00KY& z2mk>f00e*l5C8%|00;nqJBUDsZNG3m(>HY4{-NucCb1I7j_a9@pzHq*?}rxjO%(L~ z$oEAeVFLm{00;m9AOHk_01yBIKmZ5;f!mtEl*8>0V_{2sP*L_tQl#|#OUqBcyZr7# zc=_oCs@SWRdS&VS^5xmDjOHdL=q)El$12(k5)I zG#;N~i0>H|pZ5CJ$>N3QNa1^p!I9#+NPKGvZpun=T6u7CCAcd~-y!*HuPYbNm!5iE z!)f`pIY1?hBWo!!FSz|mr9@m@SXlYS1=fH|C$u7#%Ys2A)v$apHl8|^Kb)P?sM2O! zOQCfA-+_K_L7$?3pd$JUk+1;)AOHk_01yBIKmZ5;0U!VbfB+Bx0{;gDnrwc7&Hy=d zq1tt!*mNOVT{eF+@xg5U-*?`E-bcrf>idL9*nj{K00KY&2mk>f00e*l5C8%|00?}x z1o|B=f9GAUR)yR|A1S_kdFA<+i|5|C`n`qHPfiz?PL^iRue|nZ@y*k{Ul?OsY4LQ8 zuPxkXcloz&(Q;8s?_Df?>ul-bxxUg<^JSyAU13}4`L}Ci&~J13v7Ui)^@YnT&;FQv z)uH&x!piL1%ZsOqrxvsE?4Q*V%3kQhiLSI^9s78ZL$ z8u8kZ#00`auLv$ad3Jg3bm{odbRNuHdFIXD9WT=DrP^7*`u}I!rJ-~n00e*l5C8%| z00;m9AOHk_01yBIH%fpm^3xd-w*F6qw*F6qw*F6qw*F59TmKKXELza-&=1j3G=R4F zKKA{}_k!;+-@U%|-Yedpdrx`4>Xp1U&mTRnd%o@&@_er4pDiDdU0?$OKmZ5;0U!Vb zfB+Bx0zlx_5$Fs${r>i5b~87*RF70{@U5{X*}2)|tgh}HFgYu$J8v>MlOKSmcW=}C zjB8cHZ*rzCt2l2gJFATYY&AJI2ylbRxj}&IP0kGhTxW7_5Fjc$D~$v6nVcI0=ruVv z2+(74ZV+IL$+y)_?E9_nvhOXj3v56D2mk>f00e*l5C8%|00;m9AOHkzB>}J9uiej12jD@Y zrEWG_O2BBzn~WCrbBoP1Z8TbHtI<+67%h3d(ZcJvrO>({qow+cmf|&9vd3uQ7H+XQ z6SvV)n~j#zWVB?L(ZWt{v55(X(NgV3OR*U(+3F!Y*}eZH2CRMmpZxy+ztG>%U(p}X zhqtonATJ;Q1b_e#00KY&2mk>f00e*l5C8%|;MNjw+k%4j%m0pwv{$69BCXA~V6*0{ zjsGuM(8uTs`Uw37UAnckg4}=r5C8%|00;m9AOHk_01yBIKmZ5;ftw}Zv9$^Op@Vj# zVQsOsHS;F}dTm>T^3xCkz%S7|=mL5Ty@Jl5r^%fF$I#>GFiN6_ 
z(H_)~?nW{SkzHT|0zd!=00AHX1b_e#00KY&2mk>faK{kPAD9;javdiu<<_? bR=r}F%}0GZp^@K~%y! diff --git a/server/predictors/model_predictor.py b/server/predictors/model_predictor.py index 1f57caf..2463543 100644 --- a/server/predictors/model_predictor.py +++ b/server/predictors/model_predictor.py @@ -18,39 +18,90 @@ from models.mlstm_model import MLSTMTransformer as MatrixLSTM from models.kan_model import KANForecaster from models.tcn_model import TCNForecaster from models.optimized_kan_forecaster import OptimizedKANForecaster +from models.cnn_bilstm_attention import CnnBiLstmAttention +import xgboost as xgb from analysis.trend_analysis import analyze_prediction_result from utils.visualization import plot_prediction_results from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH +from models.model_registry import get_predictor, register_predictor + +def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days): + """ + 默认的PyTorch模型预测逻辑,支持自动回归。 + """ + config = checkpoint['config'] + scaler_X = checkpoint['scaler_X'] + scaler_y = checkpoint['scaler_y'] + features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']) + sequence_length = config['sequence_length'] + + if start_date: + start_date_dt = pd.to_datetime(start_date) + prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length) + else: + prediction_input_df = product_df.tail(sequence_length) + start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1) + + if len(prediction_input_df) < sequence_length: + raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。") + + history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days) + + all_predictions = [] + current_sequence_df = prediction_input_df.copy() + 
+ for _ in range(future_days): + X_current_scaled = scaler_X.transform(current_sequence_df[features].values) + # **核心改进**: 智能判断模型类型并调用相应的预测方法 + if isinstance(model, xgb.Booster): + # XGBoost 模型预测路径 + X_input_reshaped = X_current_scaled.reshape(1, -1) + d_input = xgb.DMatrix(X_input_reshaped) + # **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略 + y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration)) + next_step_pred_scaled = y_pred_scaled.reshape(1, -1) + else: + # 默认 PyTorch 模型预测路径 + X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE) + with torch.no_grad(): + y_pred_scaled = model(X_input).cpu().numpy() + next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1) + next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0])) + + next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1) + all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled}) + + new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]} + new_row_df = pd.DataFrame([new_row]) + current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True) + + predictions_df = pd.DataFrame(all_predictions) + return predictions_df, history_for_chart_df, prediction_input_df + +# 注册默认的PyTorch预测器 +register_predictor('default', default_pytorch_predictor) +# 将增强后的默认预测器也注册给xgboost +register_predictor('xgboost', default_pytorch_predictor) +# 将新模型也注册给默认预测器 +register_predictor('cnn_bilstm_attention', default_pytorch_predictor) + def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: 
Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30): """ - 加载已训练的模型并进行预测 (v3版 - 支持自动回归) - - 参数: - ... (同上, 新增 history_lookback_days) - history_lookback_days: 用于图表展示的历史数据天数 - - 返回: - 预测结果和分析 + 加载已训练的模型并进行预测 (v4版 - 插件式架构) """ try: - print(f"v3版预测函数启动,模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}") - if not os.path.exists(model_path): - print(f"模型文件 {model_path} 不存在") - return None - - # 加载销售数据 + raise FileNotFoundError(f"模型文件 {model_path} 不存在") + + # --- 数据加载部分保持不变 --- from utils.multi_store_data_utils import aggregate_multi_store_data if training_mode == 'store' and store_id: - # 先从原始数据加载一次以获取店铺名称,聚合会丢失此信息 from utils.multi_store_data_utils import load_multi_store_data store_df_for_name = load_multi_store_data(store_id=store_id) product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}" - - # 然后再进行聚合获取用于预测的数据 product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH) elif training_mode == 'global': product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH) @@ -60,124 +111,75 @@ def load_model_and_predict(model_path: str, product_id: str, model_type: str, st product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id if product_df.empty: - print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") - return None + raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") - # 加载模型和配置 + # --- 模型加载与实例化 (重构) --- try: torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler]) except Exception: pass checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False) - if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint: - print("模型文件不完整,缺少config或scaler") - return None - - config = checkpoint['config'] - scaler_X = checkpoint['scaler_X'] - scaler_y = checkpoint['scaler_y'] - - # 创建模型实例 - 
# (此处省略了与原版本相同的模型创建代码,以保持简洁) - if model_type == 'transformer': - model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE) - elif model_type == 'mlstm': - model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE) - elif model_type == 'kan': - model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) - elif model_type == 'optimized_kan': - model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) - elif model_type == 'tcn': - model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE) + config = checkpoint.get('config', {}) + loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型 + + # 根据模型类型决定如何获取模型实例 + if loaded_model_type == 'xgboost': + # 对于XGBoost, 模型对象直接保存在'model_state_dict'键中 + model = checkpoint['model_state_dict'] else: - print(f"不支持的模型类型: {model_type}"); return None - - model.load_state_dict(checkpoint['model_state_dict']) - model.eval() - - # --- 核心逻辑修改:自动回归预测 --- - features = ['sales', 
'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] - sequence_length = config['sequence_length'] - - # 确定预测的起始点 - if start_date: - start_date_dt = pd.to_datetime(start_date) - # 获取预测开始日期前的 `sequence_length` 天数据作为初始输入 - prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length) - else: - # 如果未指定开始日期,则从数据的最后一天开始预测 - prediction_input_df = product_df.tail(sequence_length) - start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1) - - if len(prediction_input_df) < sequence_length: - print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。") - return None - - # 准备用于图表展示的历史数据 - history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days) - - # 自动回归预测循环 - all_predictions = [] - current_sequence_df = prediction_input_df.copy() - - print(f"开始自动回归预测,共 {future_days} 天...") - for i in range(future_days): - # 准备当前序列的输入张量 - X_current_scaled = scaler_X.transform(current_sequence_df[features].values) - X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE) - - # 模型进行一次预测(可能预测出多个点,但我们只用第一个) - with torch.no_grad(): - y_pred_scaled = model(X_input).cpu().numpy() + # 对于PyTorch模型, 需要重新构建实例并加载state_dict + if loaded_model_type == 'transformer': + model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE) + elif loaded_model_type == 'mlstm': + model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], 
num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE) + elif loaded_model_type == 'kan': + model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) + elif loaded_model_type == 'optimized_kan': + model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) + elif loaded_model_type == 'tcn': + model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE) + elif loaded_model_type == 'cnn_bilstm_attention': + model = CnnBiLstmAttention( + input_dim=config['input_dim'], + output_dim=config['output_dim'], + sequence_length=config['sequence_length'] + ).to(DEVICE) + else: + raise ValueError(f"不支持的模型类型: {loaded_model_type}") - # 提取下一个时间点的预测值 - next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1) - next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0] - next_step_pred_unscaled = float(max(0, next_step_pred_unscaled)) # 确保销量不为负,并转换为标准float + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() - # 获取新预测的日期 - next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1) - all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled}) + # --- 动态调用预测器 --- + predictor_function = get_predictor(loaded_model_type) + if not predictor_function: + raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现") - # 构建新的一行数据,用于更新输入序列 - new_row = { - 'date': next_date, - 'sales': next_step_pred_unscaled, - 'weekday': next_date.weekday(), - 'month': next_date.month, - 'is_holiday': 0, - 'is_weekend': 1 if 
next_date.weekday() >= 5 else 0, - 'is_promotion': 0, - 'temperature': current_sequence_df['temperature'].iloc[-1] # 沿用最后一天的温度 - } - - # 更新序列:移除最旧的一行,添加最新预测的一行 - new_row_df = pd.DataFrame([new_row]) - current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True) + predictions_df, history_for_chart_df, prediction_input_df = predictor_function( + model=model, + checkpoint=checkpoint, + product_df=product_df, + future_days=future_days, + start_date=start_date, + history_lookback_days=history_lookback_days + ) - predictions_df = pd.DataFrame(all_predictions) - print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。") - - # 分析与可视化 + # --- 分析与返回部分保持不变 --- analysis = None if analyze_result: try: - y_pred_for_analysis = predictions_df['predicted_sales'].values - # 使用初始输入序列的特征进行分析 - initial_features_for_analysis = prediction_input_df[features].values - analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis) + analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values) except Exception as e: print(f"分析预测结果失败: {str(e)}") - # 在返回前,将DataFrame转换为前端期望的JSON数组格式 history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else [] prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else [] return { 'product_id': product_id, 'product_name': product_name, - 'model_type': model_type, - 'predictions': prediction_data_json, # 兼容旧字段,使用已转换的json + 'model_type': loaded_model_type, + 'predictions': prediction_data_json, 'prediction_data': prediction_data_json, 'history_data': history_data_json, 'analysis': analysis diff --git a/server/trainers/__init__.py b/server/trainers/__init__.py index 5c68ae3..da31755 100644 --- a/server/trainers/__init__.py +++ b/server/trainers/__init__.py @@ -2,18 +2,44 @@ 药店销售预测系统 - 模型训练模块 """ -from .mlstm_trainer 
import train_product_model_with_mlstm -from .kan_trainer import train_product_model_with_kan -from .tcn_trainer import train_product_model_with_tcn -from .transformer_trainer import train_product_model_with_transformer +import os +import glob +import importlib -# 默认训练函数 -from .mlstm_trainer import train_product_model_with_mlstm as train_product_model +_TRAINERS_LOADED = False + +def discover_trainers(): + """ + 自动发现并加载所有训练器插件。 + 使用一个标志位确保这个过程只执行一次。 + """ + global _TRAINERS_LOADED + if _TRAINERS_LOADED: + return + + print("🚀 开始发现并加载训练器插件...") + + package_dir = os.path.dirname(__file__) + module_name = __name__ + + trainer_files = glob.glob(os.path.join(package_dir, "*_trainer.py")) + + for f in trainer_files: + base_name = os.path.basename(f) + if base_name.startswith('__'): + continue + + module_stem = base_name.replace('.py', '') + + try: + # 动态导入模块以触发自注册 + importlib.import_module(f".{module_stem}", package=module_name) + except ImportError as e: + print(f"⚠️ 加载训练器 {module_stem} 失败: {e}") + + _TRAINERS_LOADED = True + print("✅ 所有训练器插件加载完成。") + +# 在包被首次导入时,自动执行发现过程 +discover_trainers() -__all__ = [ - 'train_product_model', - 'train_product_model_with_mlstm', - 'train_product_model_with_kan', - 'train_product_model_with_tcn', - 'train_product_model_with_transformer' -] diff --git a/server/trainers/cnn_bilstm_attention_trainer.py b/server/trainers/cnn_bilstm_attention_trainer.py new file mode 100644 index 0000000..35a9149 --- /dev/null +++ b/server/trainers/cnn_bilstm_attention_trainer.py @@ -0,0 +1,118 @@ +# -*- coding: utf-8 -*- +""" +CNN-BiLSTM-Attention 模型训练器 +""" + +import torch +import torch.optim as optim +import numpy as np + +from models.model_registry import register_trainer +from utils.model_manager import model_manager +from analysis.metrics import evaluate_model +from utils.data_utils import create_dataset +from sklearn.preprocessing import MinMaxScaler + +# 导入新创建的模型 +from models.cnn_bilstm_attention import CnnBiLstmAttention + +def 
train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): + """ + 使用 CNN-BiLSTM-Attention 模型进行训练。 + 函数签名遵循系统标准。 + """ + print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'") + + # --- 1. 数据准备 --- + features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] + + X = product_df[features].values + y = product_df[['sales']].values + + scaler_X = MinMaxScaler(feature_range=(0, 1)) + scaler_y = MinMaxScaler(feature_range=(0, 1)) + + X_scaled = scaler_X.fit_transform(X) + y_scaled = scaler_y.fit_transform(y) + + train_size = int(len(X_scaled) * 0.8) + X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:] + y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:] + + trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon) + testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon) + + # 转换为 PyTorch Tensors + trainX = torch.from_numpy(trainX).float() + trainY = torch.from_numpy(trainY).float() + testX = torch.from_numpy(testX).float() + testY = torch.from_numpy(testY).float() + + # --- 2. 实例化模型和优化器 --- + input_dim = trainX.shape[2] + + model = CnnBiLstmAttention( + input_dim=input_dim, + output_dim=forecast_horizon, + sequence_length=sequence_length + ) + + optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001)) + criterion = torch.nn.MSELoss() + + # --- 3. 训练循环 --- + print("开始训练 CNN-BiLSTM-Attention 模型...") + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + + outputs = model(trainX) + loss = criterion(outputs, trainY.squeeze(-1)) # 确保目标维度匹配 + + loss.backward() + optimizer.step() + + if (epoch + 1) % 10 == 0: + print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}') + + # --- 4. 
模型评估 --- + model.eval() + with torch.no_grad(): + test_pred_scaled = model(testX) + + test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy()) + test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy()) + + metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten()) + print(f"模型评估完成: RMSE={metrics['rmse']:.4f}") + + # --- 5. 模型保存 --- + model_data = { + 'model_state_dict': model.state_dict(), + 'scaler_X': scaler_X, + 'scaler_y': scaler_y, + 'config': { + 'model_type': 'cnn_bilstm_attention', + 'input_dim': input_dim, + 'output_dim': forecast_horizon, + 'sequence_length': sequence_length, + 'features': features + }, + 'metrics': metrics + } + + final_model_path, final_version = model_manager.save_model( + model_data=model_data, + product_id=product_id, + model_type='cnn_bilstm_attention', + store_id=store_id, + training_mode=training_mode, + aggregation_method=aggregation_method, + product_name=product_df['product_name'].iloc[0] + ) + + print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}") + return model, metrics, final_version, final_model_path + +# --- 关键步骤: 将训练器注册到系统中 --- +register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention) \ No newline at end of file diff --git a/server/trainers/kan_trainer.py b/server/trainers/kan_trainer.py index 7aa1d95..0d71e7d 100644 --- a/server/trainers/kan_trainer.py +++ b/server/trainers/kan_trainer.py @@ -349,4 +349,9 @@ def train_product_model_with_kan(product_id, model_identifier, product_df=None, print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}") - return model, metrics \ No newline at end of file + return model, metrics + +# --- 将此训练器注册到系统中 --- +from models.model_registry import register_trainer +register_trainer('kan', train_product_model_with_kan) +register_trainer('optimized_kan', train_product_model_with_kan) \ No newline at end of file diff --git a/server/trainers/mlstm_trainer.py b/server/trainers/mlstm_trainer.py index 
6098c11..d3d0bf1 100644 --- a/server/trainers/mlstm_trainer.py +++ b/server/trainers/mlstm_trainer.py @@ -514,4 +514,8 @@ def train_product_model_with_mlstm( emit_progress(f"✅ mLSTM模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics) - return model, metrics, epochs, final_model_path \ No newline at end of file + return model, metrics, epochs, final_model_path + +# --- 将此训练器注册到系统中 --- +from models.model_registry import register_trainer +register_trainer('mlstm', train_product_model_with_mlstm) \ No newline at end of file diff --git a/server/trainers/tcn_trainer.py b/server/trainers/tcn_trainer.py index f4a5638..34f10ba 100644 --- a/server/trainers/tcn_trainer.py +++ b/server/trainers/tcn_trainer.py @@ -379,4 +379,8 @@ def train_product_model_with_tcn( emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics) - return model, metrics, epochs, final_model_path \ No newline at end of file + return model, metrics, epochs, final_model_path + +# --- 将此训练器注册到系统中 --- +from models.model_registry import register_trainer +register_trainer('tcn', train_product_model_with_tcn) \ No newline at end of file diff --git a/server/trainers/transformer_trainer.py b/server/trainers/transformer_trainer.py index 4a70d9a..06d6af3 100644 --- a/server/trainers/transformer_trainer.py +++ b/server/trainers/transformer_trainer.py @@ -406,4 +406,8 @@ def train_product_model_with_transformer( 'version': final_version } - return model, final_metrics, epochs \ No newline at end of file + return model, final_metrics, epochs + +# --- 将此训练器注册到系统中 --- +from models.model_registry import register_trainer +register_trainer('transformer', train_product_model_with_transformer) \ No newline at end of file diff --git a/server/trainers/xgboost_trainer.py b/server/trainers/xgboost_trainer.py new file mode 100644 index 0000000..9dff330 --- /dev/null +++ b/server/trainers/xgboost_trainer.py @@ -0,0 +1,142 @@ +""" +药店销售预测系统 - XGBoost 模型训练器 (插件式) +""" + +import time +import 
pandas as pd +import numpy as np +import xgboost as xgb +from sklearn.preprocessing import MinMaxScaler +from xgboost.callback import EarlyStopping + +# 导入核心工具 +from utils.data_utils import create_dataset +from analysis.metrics import evaluate_model +from utils.model_manager import model_manager +from models.model_registry import register_trainer + +def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): + """ + 使用 XGBoost 模型训练产品销售预测模型。 + 此函数签名与其他训练器保持一致,以兼容注册表调用。 + """ + print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'") + + # --- 1. 数据准备和验证 --- + if product_df.empty: + raise ValueError(f"产品 {product_id} 没有可用的销售数据") + + min_required_samples = sequence_length + forecast_horizon + if len(product_df) < min_required_samples: + error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。") + raise ValueError(error_msg) + + product_df = product_df.sort_values('date') + product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier + + # --- 2. 
数据预处理和适配 --- + features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] + + X = product_df[features].values + y = product_df[['sales']].values + + scaler_X = MinMaxScaler(feature_range=(0, 1)) + scaler_y = MinMaxScaler(feature_range=(0, 1)) + + X_scaled = scaler_X.fit_transform(X) + y_scaled = scaler_y.fit_transform(y) + + train_size = int(len(X_scaled) * 0.8) + X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:] + y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:] + + trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon) + testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon) + + # **关键适配步骤**: XGBoost 需要二维输入 + trainX = trainX.reshape(trainX.shape[0], -1) + testX = testX.reshape(testX.shape[0], -1) + + # **关键适配**: 转换为 XGBoost 核心 DMatrix 格式,以使用稳定的 xgb.train API + dtrain = xgb.DMatrix(trainX, label=trainY) + dtest = xgb.DMatrix(testX, label=testY) + + # --- 3. 模型训练 (使用核心 xgb.train API) --- + xgb_params = { + 'learning_rate': kwargs.get('learning_rate', 0.08), + 'subsample': kwargs.get('subsample', 0.75), + 'colsample_bytree': kwargs.get('colsample_bytree', 1), + 'max_depth': kwargs.get('max_depth', 7), + 'gamma': kwargs.get('gamma', 0), + 'objective': 'reg:squarederror', + 'eval_metric': 'rmse', # eval_metric 在这里是原生支持的 + 'n_jobs': -1 + } + n_estimators = kwargs.get('n_estimators', 500) + + print("开始训练XGBoost模型 (使用核心xgb.train API)...") + start_time = time.time() + + evals_result = {} + model = xgb.train( + params=xgb_params, + dtrain=dtrain, + num_boost_round=n_estimators, + evals=[(dtrain, 'train'), (dtest, 'test')], + early_stopping_rounds=50, # early_stopping_rounds 在这里是原生支持的 + evals_result=evals_result, + verbose_eval=False + ) + + training_time = time.time() - start_time + print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒") + + # --- 4. 
模型评估 --- + # 使用 model.best_iteration 获取最佳轮次的预测结果 + test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration)) + + test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon)) + test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon)) + + metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten()) + metrics['training_time'] = training_time + + print("\n模型评估指标:") + print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%") + + # --- 5. 模型保存 (借道 utils.model_manager) --- + # **关键适配点**: 我们将完整的XGBoost模型对象存入字典 + # torch.save 可以序列化多种Python对象,包括sklearn模型 + model_data = { + 'model_state_dict': model, # 直接保存模型对象 + 'scaler_X': scaler_X, + 'scaler_y': scaler_y, + 'config': { + 'sequence_length': sequence_length, + 'forecast_horizon': forecast_horizon, + 'model_type': 'xgboost', + 'features': features, + 'xgb_params': xgb_params + }, + 'metrics': metrics, + 'loss_history': evals_result + } + + # 调用全局管理器进行保存,复用其命名和版本逻辑 + final_model_path, final_version = model_manager.save_model( + model_data=model_data, + product_id=product_id, + model_type='xgboost', + store_id=store_id, + training_mode=training_mode, + aggregation_method=aggregation_method, + product_name=product_name + ) + + print(f"XGBoost模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}") + + # 返回值遵循统一格式 + return model, metrics, final_version, final_model_path + +# --- 将此训练器注册到系统中 --- +register_trainer('xgboost', train_product_model_with_xgboost) \ No newline at end of file diff --git a/server/utils/model_manager.py b/server/utils/model_manager.py index 6ab4e46..60164d0 100644 --- a/server/utils/model_manager.py +++ b/server/utils/model_manager.py @@ -280,48 +280,41 @@ class ModelManager: if len(parts) < 3: return None # 格式不符合基本要求 - model_type = parts[0] - mode = parts[1] - + # **核心修复**: 采用更健壮的、从后往前的解析逻辑,以支持带下划线的模型名称 try: - if mode == 'store' and 
len(parts) >= 3: - # {model_type}_store_{store_id}_{version} - version = parts[-1] - store_id = '_'.join(parts[2:-1]) - return { - 'model_type': model_type, - 'training_mode': 'store', - 'store_id': store_id, - 'version': version, - 'product_id': None, - 'aggregation_method': None - } - elif mode == 'global' and len(parts) >= 3: - # {model_type}_global_{aggregation_method}_{version} - version = parts[-1] - aggregation_method = '_'.join(parts[2:-1]) - return { - 'model_type': model_type, - 'training_mode': 'global', - 'aggregation_method': aggregation_method, - 'version': version, - 'product_id': None, - 'store_id': None - } - elif mode == 'product' and len(parts) >= 3: - # {model_type}_product_{product_id}_{version} - version = parts[-1] - product_id = '_'.join(parts[2:-1]) + version = parts[-1] + identifier = parts[-2] + mode_candidate = parts[-3] + + if mode_candidate == 'product': + model_type = '_'.join(parts[:-3]) return { 'model_type': model_type, 'training_mode': 'product', - 'product_id': product_id, + 'product_id': identifier, 'version': version, - 'store_id': None, - 'aggregation_method': None } + elif mode_candidate == 'store': + model_type = '_'.join(parts[:-3]) + return { + 'model_type': model_type, + 'training_mode': 'store', + 'store_id': identifier, + 'version': version, + } + elif mode_candidate == 'global': + model_type = '_'.join(parts[:-3]) + return { + 'model_type': model_type, + 'training_mode': 'global', + 'aggregation_method': identifier, + 'version': version, + } + except IndexError: + # 如果文件名部分不够,则解析失败 + pass except Exception as e: - print(f"解析新版v2文件名失败 {filename}: {e}") + print(f"解析文件名失败 {filename}: {e}") return None diff --git a/xz修改记录日志和启动依赖.md b/xz修改记录日志和启动依赖.md index 00b767a..0e9e683 100644 --- a/xz修改记录日志和启动依赖.md +++ b/xz修改记录日志和启动依赖.md @@ -1,5 +1,5 @@ ### 根目录启动 -`uv pip install loguru numpy pandas torch matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn` +`uv pip install loguru numpy pandas torch 
matplotlib flask flask_cors flask_socketio flasgger scikit-learn tqdm pytorch_tcn pyarrow` ### UI `npm install` `npm run dev` diff --git a/xz新模型添加流程.md b/xz新模型添加流程.md new file mode 100644 index 0000000..7e88326 --- /dev/null +++ b/xz新模型添加流程.md @@ -0,0 +1,222 @@ +# 如何向系统添加新模型 + +本指南详细说明了如何向本预测系统添加一个全新的预测模型。系统采用灵活的插件式架构,集成新模型的过程非常模块化,主要围绕 **模型(Model)**、**训练器(Trainer)** 和 **预测器(Predictor)** 这三个核心组件进行。 + +## 核心理念 + +系统的核心是 `models/model_registry.py`,它维护了两个独立的注册表:一个用于训练函数,另一个用于预测函数。添加新模型的本质就是: + +1. **定义模型**:创建模型的架构。 +2. **创建训练器**:编写一个函数来训练这个模型,并将其注册到训练器注册表。 +3. **集成预测器**:确保系统知道如何加载模型并用它来预测,然后将预测逻辑注册到预测器注册表。 + +--- + +## 第 1 步:定义模型架构 + +首先,您需要在 `ShopTRAINING/server/models/` 目录下创建一个新的 Python 文件来定义您的模型。 + +**示例:创建 `ShopTRAINING/server/models/my_new_model.py`** + +如果您的新模型是基于 PyTorch 的,它应该是一个继承自 `torch.nn.Module` 的类。 + +```python +# file: ShopTRAINING/server/models/my_new_model.py + +import torch +import torch.nn as nn + +class MyNewModel(nn.Module): + def __init__(self, input_features, hidden_size, output_sequence_length): + """ + 定义模型的层和结构。 + """ + super(MyNewModel, self).__init__() + self.layer1 = nn.Linear(input_features, hidden_size) + self.relu = nn.ReLU() + self.layer2 = nn.Linear(hidden_size, output_sequence_length) + # ... 
可添加更复杂的结构 + + def forward(self, x): + """ + 定义数据通过模型的前向传播路径。 + x 的形状通常是 (batch_size, sequence_length, num_features) + """ + # 确保输入是正确的形状 + # 例如,对于简单的线性层,可能需要展平 + batch_size, seq_len, features = x.shape + x = x.view(batch_size * seq_len, features) # 展平 + + out = self.layer1(x) + out = self.relu(out) + out = self.layer2(out) + + # 恢复形状以匹配输出 + out = out.view(batch_size, seq_len, -1) + # 通常我们只关心序列的最后一个预测 + return out[:, -1, :] +``` + +--- + +## 第 2 步:创建模型训练器 + +接下来,在 `ShopTRAINING/server/trainers/` 目录下创建一个新的训练器文件。这个文件负责模型的整个训练、评估和保存流程。 + +**示例:创建 `ShopTRAINING/server/trainers/my_new_model_trainer.py`** + +这个训练函数需要遵循系统中其他训练器(如 `xgboost_trainer.py`)的统一函数签名,并使用 `@register_trainer` 装饰器或在文件末尾调用 `register_trainer` 函数。 + +```python +# file: ShopTRAINING/server/trainers/my_new_model_trainer.py + +import torch +import torch.optim as optim +from models.model_registry import register_trainer +from utils.model_manager import model_manager +from analysis.metrics import evaluate_model +from models.my_new_model import MyNewModel # 导入您的新模型 + +# 遵循系统的标准函数签名 +def train_with_mynewmodel(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): + print(f"🚀 MyNewModel 训练器启动: model_identifier='{model_identifier}'") + + # --- 1. 数据准备 --- + # (此处省略了数据加载、标准化和创建数据集的详细代码, + # 您可以参考 xgboost_trainer.py 或其他训练器中的实现) + # ... + # 假设您已准备好 trainX, trainY, testX, testY, scaler_y 等变量 + # trainX = ... + # trainY = ... + # testX = ... + # testY = ... + # scaler_y = ... + # features = [...] + + # --- 2. 实例化模型和优化器 --- + input_dim = trainX.shape[2] # 获取特征数量 + hidden_size = 64 # 示例超参数 + + model = MyNewModel( + input_features=input_dim, + hidden_size=hidden_size, + output_sequence_length=forecast_horizon + ) + optimizer = optim.Adam(model.parameters(), lr=0.001) + criterion = torch.nn.MSELoss() + + # --- 3. 
训练循环 --- + print("开始训练 MyNewModel...") + for epoch in range(epochs): + model.train() + optimizer.zero_grad() + outputs = model(trainX) + loss = criterion(outputs, trainY) + loss.backward() + optimizer.step() + if (epoch + 1) % 10 == 0: + print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}') + + # --- 4. 模型评估 --- + model.eval() + with torch.no_grad(): + test_pred_scaled = model(testX) + + # 反标准化并计算指标 + # ... (参考其他训练器) + metrics = {'rmse': 0.0, 'mae': 0.0, 'r2': 0.0, 'mape': 0.0} # 示例 + + # --- 5. 模型保存 --- + model_data = { + 'model_state_dict': model.state_dict(), + 'scaler_X': None, # 替换为您的 scaler_X + 'scaler_y': scaler_y, + 'config': { + 'model_type': 'mynewmodel', # **关键**: 使用唯一的模型名称 + 'input_dim': input_dim, + 'hidden_size': hidden_size, + 'sequence_length': sequence_length, + 'forecast_horizon': forecast_horizon, + 'features': features + }, + 'metrics': metrics + } + + model_manager.save_model( + model_data=model_data, + product_id=product_id, + model_type='mynewmodel', # **关键**: 再次确认模型名称 + # ... 其他参数 + ) + + print("✅ MyNewModel 模型训练并保存完成!") + return model, metrics, "v1", "path/to/model" # 返回值遵循统一格式 + +# --- 关键步骤: 将训练器注册到系统中 --- +register_trainer('mynewmodel', train_with_mynewmodel) +``` + +--- + +## 第 3 步:集成模型预测器 + +最后,您需要让系统知道如何加载和使用您的新模型进行预测。这需要在 `ShopTRAINING/server/predictors/model_predictor.py` 中进行两处修改。 + +**文件: `ShopTRAINING/server/predictors/model_predictor.py`** + +1. **让系统知道如何构建您的模型实例** + + 在 `load_model_and_predict` 函数中,有一个 `if/elif` 结构用于根据模型类型实例化不同的模型。您需要为 `MyNewModel` 添加一个新的分支。 + + ```python + # 在 model_predictor.py 中 + + # 首先,导入您的新模型类 + from models.my_new_model import MyNewModel + + # ... 在 load_model_and_predict 函数内部 ... + + # ... 其他模型的 elif 分支 ... + elif loaded_model_type == 'tcn': + model = TCNForecaster(...) 
+ + # vvv 添加这个新的分支 vvv + elif loaded_model_type == 'mynewmodel': + model = MyNewModel( + input_features=config['input_dim'], + hidden_size=config['hidden_size'], + output_sequence_length=config['forecast_horizon'] + ).to(DEVICE) + # ^^^ 添加结束 ^^^ + + else: + raise ValueError(f"不支持的模型类型: {loaded_model_type}") + + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + ``` + +2. **注册预测逻辑** + + 如果您的模型是一个标准的 PyTorch 模型,并且其预测逻辑与现有的模型(如 Transformer, KAN)相同,您可以直接复用 `default_pytorch_predictor`。只需在文件末尾添加一行注册代码即可。 + + ```python + # 在 model_predictor.py 文件末尾 + + # ... + # 将增强后的默认预测器也注册给xgboost + register_predictor('xgboost', default_pytorch_predictor) + + # vvv 添加这行代码 vvv + # 让 'mynewmodel' 也使用通用的 PyTorch 预测器 + register_predictor('mynewmodel', default_pytorch_predictor) + # ^^^ 添加结束 ^^^ + ``` + + 如果您的模型需要特殊的预测逻辑(例如,像 XGBoost 那样有不同的输入格式或调用方式),您可以复制 `default_pytorch_predictor` 创建一个新函数,修改其内部逻辑,然后将新函数注册给 `'mynewmodel'`。 + +--- + +## 总结 + +完成以上三个步骤后,您的新模型 `MyNewModel` 就已完全集成到系统中了。系统会自动在 `trainers` 目录中发现您的新训练器(前提是训练器文件名以 `_trainer.py` 结尾,例如 `my_new_model_trainer.py`,否则不会被自动加载)。当您通过 API 或界面选择 `mynewmodel` 进行训练和预测时,系统将自动调用您刚刚编写和注册的所有相应逻辑。 \ No newline at end of file