From a1d9c60e618a5bda6f8f033f6adb7c0e4416a3eb Mon Sep 17 00:00:00 2001 From: LYFxiaoan Date: Wed, 16 Jul 2025 16:24:08 +0800 Subject: [PATCH] =?UTF-8?q?-=E5=AE=8C=E5=96=84=E5=BA=97=E9=93=BA=E9=A2=84?= =?UTF-8?q?=E6=B5=8B=E6=A8=A1=E5=9D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../views/prediction/StorePredictionView.vue | 64 +++++++++++++----- lyf开发日志记录文档.md | 46 +++++++++++++ prediction_history.db | Bin 245760 -> 286720 bytes server/api.py | 49 ++++++++------ server/core/config.py | 16 +++-- server/core/predictor.py | 21 +++--- server/predictors/model_predictor.py | 30 ++++---- server/trainers/mlstm_trainer.py | 12 ++-- server/trainers/tcn_trainer.py | 11 ++- server/trainers/transformer_trainer.py | 13 ++-- 10 files changed, 171 insertions(+), 91 deletions(-) diff --git a/UI/src/views/prediction/StorePredictionView.vue b/UI/src/views/prediction/StorePredictionView.vue index 692a104..22c472e 100644 --- a/UI/src/views/prediction/StorePredictionView.vue +++ b/UI/src/views/prediction/StorePredictionView.vue @@ -208,11 +208,15 @@ const startPrediction = async () => { future_days: form.future_days, start_date: form.start_date, analyze_result: form.analyze_result, - store_id: form.store_id + store_id: form.store_id, + // 修正:对于店铺模型,product_id应传递店铺的标识符 + product_id: `store_${form.store_id}` } - const response = await axios.post('/api/predict', payload) + // 修正API端点 + const response = await axios.post('/api/prediction', payload) if (response.data.status === 'success') { - predictionResult.value = response.data.data + // 修正:数据现在直接在响应的顶层 + predictionResult.value = response.data ElMessage.success('预测完成!') await nextTick() renderChart() @@ -231,30 +235,58 @@ const renderChart = () => { if (chart) { chart.destroy() } - const predictions = predictionResult.value.predictions - const labels = predictions.map(p => p.date) - const data = predictions.map(p => p.sales) + + const historyData = predictionResult.value.history_data || [] + const predictionData = predictionResult.value.prediction_data || [] + + const labels = [ + ...historyData.map(p => p.date), + ...predictionData.map(p => p.date) + ] + + const historySales = historyData.map(p => p.sales) + // 预测数据需要填充与历史数据等长的null值,以保证图表正确对齐 + const predictionSales = [ + ...Array(historyData.length).fill(null), + ...predictionData.map(p => p.predicted_sales) + ] + chart = new Chart(chartCanvas.value, { type: 'line', data: { labels, - datasets: [{ - label: '预测销量', - data, - borderColor: '#409EFF', - backgroundColor: 'rgba(64, 158, 255, 0.1)', - tension: 0.4, - fill: true - }] + datasets: [ + { + label: '历史销量', + data: historySales, + borderColor: '#67C23A', + backgroundColor: 'rgba(103, 194, 58, 0.1)', + fill: false, + tension: 0.4 + }, + { + label: '预测销量', + data: predictionSales, + borderColor: '#409EFF', + backgroundColor: 'rgba(64, 158, 255, 0.1)', + borderDash: [5, 5], // 虚线 + fill: false, + tension: 0.4 + } + ] }, options: { responsive: true, plugins: { title: { display: true, - text: '销量预测趋势图' + text: '店铺销量历史与预测趋势图' } - } + }, + interaction: { + intersect: false, + mode: 'index', + }, } }) } diff --git a/lyf开发日志记录文档.md b/lyf开发日志记录文档.md index 541e7b3..636425a 100644 --- a/lyf开发日志记录文档.md +++ b/lyf开发日志记录文档.md @@ -121,3 +121,49 @@ ### 11:45 - 项目总结与文档归档 - **任务**: 根据用户要求,回顾整个调试过程,将所有问题、解决方案、优化思路和最终结论,按照日期和时间顺序,整理并更新到本开发日志中,形成一份高质量的技术档案。 - **结果**: 本文档已更新完成。 + + +### 13:15 - 最终修复:根治模型标识符不一致问题 +- **问题**: 经过再次测试和日志分析,发现即便是修正后,店铺模型的 `model_identifier` 在训练时依然被错误地构建为 `01010023_store_01010023`。 +- **根本原因**: `server/core/predictor.py` 的 `train_model` 方法中,在 `training_mode == 'store'` 的分支下,构建 `model_identifier` 的逻辑存在冗余和错误。 +- **最终解决方案**: 删除了错误的拼接逻辑 `model_identifier = f"{store_id}_{product_id}"`,直接使用在之前步骤中已经被正确赋值为 `f"store_{store_id}"` 的 `product_id` 变量作为 `model_identifier`。这确保了从训练、保存到最终API查询,店铺模型的唯一标识符始终保持一致。 + + +### 13:30 - 最终修复(第二轮):根治模型保存路径错误 +- **问题**: 即便修复了标识符,模型版本依然无法加载。 +- **根本原因**: 通过分析训练日志,发现所有训练器(`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`)中的 `save_checkpoint` 函数,都会强制在 `saved_models` 目录下创建一个 `checkpoints` 子目录,并将所有模型文件保存在其中。而负责查找模型的 `get_model_versions` 函数只在根目录查找,导致模型永远无法被发现。 +- **最终解决方案**: 逐一修改了所有相关训练器文件中的 `save_checkpoint` 函数,移除了创建和使用 `checkpoints` 子目录的逻辑,确保所有模型都直接保存在 `saved_models` 根目录下。 +- **结论**: 至此,模型保存的路径与查找的路径完全统一,从根本上解决了模型版本无法加载的问题。 + + +### 13:40 - 最终修复(第三轮):统一所有训练器的模型保存逻辑 +- **问题**: 在修复了 `transformer_trainer.py` 后,发现 `mlstm_trainer.py` 和 `tcn_trainer.py` 存在完全相同的路径和命名错误,导致问题依旧。 +- **根本原因**: `save_checkpoint` 函数在所有训练器中都被错误地实现,它们都强制创建了 `checkpoints` 子目录,并使用了错误的逻辑来拼接文件名。 +- **最终解决方案**: + 1. **逐一修复**: 逐一修改了 `transformer_trainer.py`, `mlstm_trainer.py`, 和 `tcn_trainer.py` 中的 `save_checkpoint` 函数。 + 2. **路径修复**: 移除了创建和使用 `checkpoints` 子目录的逻辑,确保模型直接保存在 `model_dir` (即 `saved_models`) 的根目录下。 + 3. **文件名修复**: 简化并修正了文件名的生成逻辑,直接使用 `product_id` 参数作为唯一标识符(该参数已由上游逻辑正确赋值为 `药品ID` 或 `store_{店铺ID}`),不再进行任何额外的、错误的拼接。 +- **结论**: 至此,所有训练器的模型保存逻辑完全统一,模型保存的路径和文件名与API的查找逻辑完全匹配,从根本上解决了模型版本无法加载的问题。 + + +--- + +## 2025-07-16 (续):端到端修复“店铺预测”图表功能 +**开发者**: lyf + +### 15:30 - 最终修复(第四轮):打通店铺预测的数据流 +- **问题**: 在解决了模型加载问题后,“店铺预测”功能虽然可以成功执行,但前端图表依然空白,不显示历史数据和预测数据。 +- **根本原因**: 参数传递在调用链中出现断裂。 + 1. `server/api.py` 在调用 `run_prediction` 时,没有传递 `training_mode`。 + 2. `server/core/predictor.py` 在调用 `load_model_and_predict` 时,没有传递 `store_id` 和 `training_mode`。 + 3. `server/predictors/model_predictor.py` 内部的数据加载逻辑,在处理店铺预测时,错误地使用了模型标识符(`store_{id}`)作为产品ID来过滤数据,导致无法加载到任何历史数据。 +- **最终解决方案 (三步修复)**: + 1. **修复 `model_predictor.py`**: 修改 `load_model_and_predict` 函数,使其能够根据 `training_mode` 参数智能地加载数据。当模式为 `'store'` 时,它会正确地聚合该店铺的所有销售数据作为历史数据,这与训练时的数据准备方式完全一致。 + 2. **修复 `predictor.py`**: 修改 `predict` 方法,将 `store_id` 和 `training_mode` 参数正确地传递给底层的 `load_model_and_predict` 函数。 + 3. **修复 `api.py`**: 修改 `predict` 路由和 `run_prediction` 辅助函数,确保 `training_mode` 参数在整个调用链中被完整传递。 +- **结论**: 通过以上修复,我们确保了从API接口到最底层数据加载器的参数传递是完整和正确的。现在,无论是药品预测还是店铺预测,系统都能够加载正确的历史数据用于图表绘制,彻底解决了图表显示空白的问题。 + +### 16:16 - 项目状态更新 +- **状态**: **所有已知问题已修复**。 +- **确认**: 用户已确认“现在药品和店铺预测流程通了”。 +- **后续**: 将本次修复过程归档至本文档。 diff --git a/prediction_history.db b/prediction_history.db index 18afacf96f921618616c68066a4ef31587659a26..fc75b895053fcc7848a571ca1108257593a220d1 100644 GIT binary patch delta 12331 zcmeG?Yiu0HdAGZy$m3JAL{bu?h@yg6w7Bb?nVs30(}SWVQD&{dnkyE7h~cg&(i`QQvkLa$p`l`)1*D(;oQD z?S{|nPWbF8z-Kpy&rY_yWy__?Lf0+1{s(#%2S3sMuRZ5`|GY2gTi^A)o-18{Ht>Vt z{-HnUdTRK&fuCEmtN&XArv`G}^Zg$g+SB{}f$4##hyQwT%bJJ#c6Z&|bAM01`}6s) z_ikUauJ4ij!TcAxH+S9Ad+VAvheii~XYj3|zwiJ3;j=wB|8K+2@EvO|_V*3E+WpGV z$MWwT%>VJ=&HbeN_?q&CVDe9l%68+h#(iXCU$`d51e=Z!zR4+KCUPBuJmRsUZ!gG! z%oe09yud7>a?x~3zGD`t6PT{+B7~V!Mh?GO4#b||90lBDHnm|OK_+)7FT=|mc9}ENAwDsMQ!1H` z&v}V*$1A$-0XfhuA>k2TGFiZ2MeHC`ctEM}JVr_-mv}{gzZ^&y!4+EooKeRG+%7VU zMO-quU9??XCPmw~_sNlEA^*96kt1$8fk#YFNZh5eC~_=NnK(Defy_mmVF9}dC=gkK zjd&%Wo5V*Y&$bB?6wk?lB^rnlqh7@%$Z-I-2V3^au37SJjtF4$9kExAw9C}z0TyOp z16qgMKqM)_F_oFHX$klq@L7c$Fi_|<9Xo)-;SMu}fWx7f6`3PA^N_bk4kVa)q*x** z7q}9@LZzbV`GDK=?V>0b5y64KTaFA6p~3@J65#`bgde~dhD@*MIlkkSaLGrzvErj)|Ix$XKe(if(Oq3IC{!bJjsJu2it#%!m=2{BQCJvi^cfsOgTi9PNU7sR#5 zi(KKhSqa2Jm_A}4lmX+W7!8WBof0WAk8&0W*Oeo~Tfb(!@4}hue{F-d@WS4%Nn@e+ z+2x*NoAth@mfv^G>FWEXzB`t0JNBN*zWseW@^9zg$iLEerf(qsRQ^of&mYLYKi|{) z&%J-z`|L_}ZExkhUDMmNllL4yQ9f!T_-A8s&k1koWN`e%-Q^Pt?ZG#<2On$?o(FKH z4e|r+r|)kM-q#*{Qv}}lPCK7lzV+i9w7u#(E5DO_+Ap|U4c#Tb^Eln3m2&I;qbLD>gOsg zyq;-g|4*bVqR9Y)FYe{qGe&sJAL^%XkmLaUcYD8bG%gLY>lc?2T^GN5oAIjFB`1WF zpVD?kzm7G2J-MV`z5!dD&rI4Dab_loh2gh}EGm11{EaOP+=4$Y30J3ozTPemM zL}mKWTyA*F(A4xQ~>o6Ir4HN#x-nS^%$49g9N5 zDX=syGc(At1s4u>DPt6KB1Nh$PACQZ?4?mxY zRQQ{-dVVX&8W<)Z$+2v(Ptr=sKVX^Q5AjTqtx6q(yLSjC2>q+#Ld^Tv7uvM4nbTybjQmZ>SE>og}^c|qT} zBfW4Nq{}O#qv4hR*1N;M`kubOf=wIAGr%?mI-rFUEs|(Ci55+?SfZuQwXr5oz(lN7 zF_9f>nkYdGN)Uq*#GnK*C_xNL5EBVmofwoLrh?+sm^eWUP7s3=#NY%mI6(}K5kpKI z;{~*0=%IBmhMt%th8|im^l&!G7M{8zCLN1m{}1%d(lUT-WBDcOW6Q5en64w<(Mm#OjeJ!IB8nyab zLrKMpS*K|^0r-`Sf}*aNb%Igh<=c!xc*{Y3a2$9JC!`LjZj0nHO53lC4ky{aOZ%tD zJgH%|5DF$p^JK`iXBUmL;Wu9}#7|1?V!{Qg9Fr}a)H6`Kg#4zIsh#>H7VT7DCR-Fn zN^|KKyrqw+d}T5G!WEEbu-AfE1cb+dOkK*HDIfr4Gn1m;jKHwF2+Wjlo30Y{dN0`O z2wEFeA6)xeUyG|rq>j*~;|e3SRme-JCBSnqY_V|VQ^wTx4CNT*I995!YASynoXQQY z9h`#Tc5JXx{rK)EgFVp5VB@8rcd7&>Q`2PY9W&UYG00J12Bv1%wWyDbAqe$o0dou4 zE(9_4gcg0&*iyCY}Z8W=Oe0Y zVL1+RFkU=0<$L#(A&ad!DNxGrJB~J9^Op%DlXA5&SmdN}=snPw8S|zwTPl@z759$A zSI+4j;Uvp1e|h6nYEgv{LW;dI@6oi+{PyJFc zx%wqz@)?M~IV^4xi^DFVr8a9}NH!oQM-ZojVU-YwCRAijvdDZa{e3v6J-kxAA2jIr zY+8dhS0}Pm?i(89s#M$oxfkzhP}6Qiw&7h3ic;R6t_J=4^`I%lqM$>zG~Trk8mEEna5BoQ&g6;>h}LnwT#N4q|J=KK2TmPBhcBL6@7VvLTnaJq3#Z}i11yzC3V{T1WTx@f=kltzt}b(cfRymv=!Y2RxAkpVtxc}i^Mh=Yd@498G%Z2maFb$czx-`@L!f+oUGp( zPX6QPzWfhgU5GhQyQ_SnQF>+to6GPFa$*NGOV(E5dLPQZ5T3ahF8V3M39U*4 z!|D-zy(+K1AJt$K#8-xBU?7^>*noxqvt(>ip_x^ma4nc*W%Zgcsfucv>I_?wCDm&} zBx|X(z9rZ`Q%@3BV?q;npSrROU(&RxOphIube;X=fk=GhWn)czvOIuD%0pUwwp>Mi z8rP1)NC}Mtff3y65&EvsNF|B%;HNG$S7-usZF-B(#{J+7S>#-4b^E^&l)m3 zp;DwjncLe;x*i%Cd1CRmK3ScpKdfr!z*P}wtGsz|QbiyPX_&CEpwx92!#B?xJFHra znKc$f&l}dY&QZbc%m1m5UvuQCW?p&T*swK4CRBL#%QBDA9p(le%a3-4E7b?PD`!tF z#Qe&r3QMxQ{!5;4+OJ&2tu(`u{H$XzH(XL02R+T|Ju$3U_7G9B0Nf4&)$1;O!DCQ6a>Uso|CqhylaYRK*})O z6bX2Vs@^Y|5C#iOF@_hI1iA2jDIU#2(I^RQykD{)4F;IaT)4m&7$zrD%QP@5VY~t5 z%^RbM*M?C>!yRG#>N$P$4U$j>?@8gwt!vX(I-9U&rKLtDqL7MH;Fcp=850L8#zdj9 z+B-0lIuuiQxY3F*4vY&@7Fmb{>8^v#%<94WqRcA8bK~}c6lAtxGP6?2I~X#868-~% zW>K(RsF;96*)4_)nM%Ro?Xm!$Enm*L@VZ&paJSS*qFP{K^>6gXMSmS?+whQ>K)|xf z%sNa(s|NVE=%T+qH4{>Kq8hwbfP4O#`o#LWSH%ev|KhtHPW;YuMrRoRt53dYT3)c*L3;+vU4n7F=3@Hqx z57-SK3Oxse3Bm{{3Xcg#37H5{4iycW4eSo$3$71_3F{9u4_yli49yN_2_6l@4M+%btU^J6nU>vg%Fh>fL XAa59h(`tv)Y5|ASY67>@Y6C4*t=T~f diff --git a/server/api.py b/server/api.py index 88d118f..1c3895f 100644 --- a/server/api.py +++ b/server/api.py @@ -1510,7 +1510,8 @@ def predict(): data = request.json product_id = data.get('product_id') model_type = data.get('model_type') - store_id = data.get('store_id') # 新增店铺ID参数 + store_id = data.get('store_id') + training_mode = 'store' if store_id else 'product' version = data.get('version') # 新增版本参数 future_days = int(data.get('future_days', 7)) start_date = data.get('start_date', '') @@ -1527,28 +1528,31 @@ def predict(): if not product_name: product_name = product_id - # 根据版本获取模型ID - if version: - # 如果指定了版本,构造版本化的模型ID - model_id = f"{product_id}_{model_type}_{version}" - # 检查指定版本的模型是否存在 - model_file_path = get_model_file_path(product_id, model_type, version) - if not os.path.exists(model_file_path): - return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型版本 {version}"}), 404 + # 根据训练模式构建模型标识符 + if training_mode == 'store': + model_identifier = f"store_{store_id}" + # 对于店铺预测,product_id实际上是store_id,但我们需要一个药品ID来获取名称,这里暂时用一个占位符 + product_name = f"店铺 {store_id} 整体" else: - # 如果没有指定版本,使用最新版本 - latest_version = get_latest_model_version(product_id, model_type) - if latest_version: - model_id = f"{product_id}_{model_type}_{latest_version}" - version = latest_version - else: - # 兼容旧的无版本模型 - model_id = get_latest_model_id(model_type, product_id) - if not model_id: - return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404 + model_identifier = product_id + product_name = get_product_name(product_id) or product_id + + # 获取模型版本 + if not version: + version = get_latest_model_version(model_identifier, model_type) + + if not version: + return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier} 的 {model_type} 类型模型"}), 404 + + # 检查模型文件是否存在 + model_file_path = get_model_file_path(model_identifier, model_type, version) + if not os.path.exists(model_file_path): + return jsonify({"status": "error", "error": f"未找到模型文件: {model_file_path}"}), 404 + + model_id = f"{model_identifier}_{model_type}_{version}" # 执行预测 - prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id) + prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode) if prediction_result is None: return jsonify({"status": "error", "error": "预测失败,预测器返回None"}), 500 @@ -2708,7 +2712,7 @@ def get_product_name(product_id): return None # 执行预测的辅助函数 -def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None): +def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product'): """执行模型预测""" try: scope_msg = f", store_id={store_id}" if store_id else ", 全局模型" @@ -2729,7 +2733,8 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve store_id=store_id, future_days=future_days, start_date=start_date, - version=version + version=version, + training_mode=training_mode ) if prediction_result is None: diff --git a/server/core/config.py b/server/core/config.py index dd67b52..fddd7b4 100644 --- a/server/core/config.py +++ b/server/core/config.py @@ -131,10 +131,11 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str: filename = f"{model_type}_model_product_{product_id}.pth" return os.path.join(DEFAULT_MODEL_DIR, filename) - # 处理新的、基于epoch的检查点命名格式 - # 文件名示例: transformer_product_17002608_epoch_best.pth - filename = f"{model_type}_product_{product_id}_epoch_{version}.pth" - return os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', filename) + # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名 + # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth + filename = f"{model_type}_{product_id}_epoch_{version}.pth" + # 修正:直接在根模型目录查找,不再使用checkpoints子目录 + return os.path.join(DEFAULT_MODEL_DIR, filename) def get_model_versions(product_id: str, model_type: str) -> list: """ @@ -149,9 +150,10 @@ def get_model_versions(product_id: str, model_type: str) -> list: """ # 直接使用传入的product_id构建搜索模式 # 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth" - pattern = f"{model_type}_product_{product_id}_epoch_*.pth" - # 在 checkpoints 子目录中查找 - search_path = os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', pattern) + # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式 + pattern = f"{model_type}_{product_id}_epoch_*.pth" + # 修正:直接在根模型目录查找,不再使用checkpoints子目录 + search_path = os.path.join(DEFAULT_MODEL_DIR, pattern) existing_files = glob.glob(search_path) # 旧格式(兼容性支持) diff --git a/server/core/predictor.py b/server/core/predictor.py index 2a22e92..9345b98 100644 --- a/server/core/predictor.py +++ b/server/core/predictor.py @@ -132,8 +132,8 @@ class PharmacyPredictor: file_path=self.data_path ) log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}") - # 将product_id设置为店铺ID,以便模型保存时使用有意义的标识 - product_id = store_id + # 将product_id设置为'store_{store_id}',与API查找逻辑保持一致 + product_id = f"store_{store_id}" except Exception as e: log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error') return None @@ -179,7 +179,7 @@ class PharmacyPredictor: # 根据训练模式构建模型标识符 if training_mode == 'store': - model_identifier = f"{store_id}_{product_id}" + model_identifier = product_id elif training_mode == 'global': model_identifier = f"global_{product_id}_{aggregation_method}" else: @@ -308,19 +308,22 @@ class PharmacyPredictor: """ # 根据训练模式构建模型标识符 if training_mode == 'store' and store_id: - model_identifier = f"{store_id}_{product_id}" + # 修正:店铺模型的标识符应该只基于店铺ID + model_identifier = f"store_{store_id}" elif training_mode == 'global': model_identifier = f"global_{product_id}_{aggregation_method}" else: model_identifier = product_id return load_model_and_predict( - model_identifier, - model_type, - future_days=future_days, - start_date=start_date, + model_identifier, + model_type, + store_id=store_id, + future_days=future_days, + start_date=start_date, analyze_result=analyze_result, - version=version + version=version, + training_mode=training_mode ) def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32, diff --git a/server/predictors/model_predictor.py b/server/predictors/model_predictor.py index 0051424..560d2ee 100644 --- a/server/predictors/model_predictor.py +++ b/server/predictors/model_predictor.py @@ -23,7 +23,7 @@ from utils.visualization import plot_prediction_results from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH -def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None): +def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'): """ 加载已训练的模型并进行预测 @@ -101,33 +101,37 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, # 加载销售数据(支持多店铺) try: - from utils.multi_store_data_utils import load_multi_store_data - if store_id: - # 加载特定店铺的数据 - product_df = load_multi_store_data( - file_path=DEFAULT_DATA_PATH, + from utils.multi_store_data_utils import aggregate_multi_store_data + + # 根据训练模式加载相应的数据 + if training_mode == 'store' and store_id: + # 店铺模型:聚合该店铺的所有产品数据 + product_df = aggregate_multi_store_data( store_id=store_id, - product_id=product_id + aggregation_method='sum', + file_path=DEFAULT_DATA_PATH ) - store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}" + store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}" prediction_scope = f"店铺 '{store_name}' ({store_id})" + # 对于店铺模型,其“产品名称”就是店铺名称 + product_name = store_name else: - # 聚合所有店铺的数据进行预测 + # 产品模型(默认):聚合该产品在所有店铺的数据 + # 此时,传入的product_id是真正的产品ID product_df = aggregate_multi_store_data( product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH ) prediction_scope = "全部店铺(聚合数据)" + product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id except Exception as e: print(f"加载数据失败: {e}") return None - + if product_df.empty: - print(f"产品 {product_id} 没有销售数据") + print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") return None - - product_name = product_df['product_name'].iloc[0] print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量") print(f"预测范围: {prediction_scope}") diff --git a/server/trainers/mlstm_trainer.py b/server/trainers/mlstm_trainer.py index 2f6eab5..c26f44e 100644 --- a/server/trainers/mlstm_trainer.py +++ b/server/trainers/mlstm_trainer.py @@ -42,16 +42,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, aggregation_method: 聚合方法 """ # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) - # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename) diff --git a/server/trainers/tcn_trainer.py b/server/trainers/tcn_trainer.py index 703ad68..acf5386 100644 --- a/server/trainers/tcn_trainer.py +++ b/server/trainers/tcn_trainer.py @@ -38,16 +38,13 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, aggregation_method: 聚合方法 """ # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename) diff --git a/server/trainers/transformer_trainer.py b/server/trainers/transformer_trainer.py index 1b7e41d..fb8a55f 100644 --- a/server/trainers/transformer_trainer.py +++ b/server/trainers/transformer_trainer.py @@ -43,17 +43,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, training_mode: 训练模式 aggregation_method: 聚合方法 """ - # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) - # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename)