diff --git a/UI/src/views/prediction/StorePredictionView.vue b/UI/src/views/prediction/StorePredictionView.vue index 692a104..22c472e 100644 --- a/UI/src/views/prediction/StorePredictionView.vue +++ b/UI/src/views/prediction/StorePredictionView.vue @@ -208,11 +208,15 @@ const startPrediction = async () => { future_days: form.future_days, start_date: form.start_date, analyze_result: form.analyze_result, - store_id: form.store_id + store_id: form.store_id, + // 修正:对于店铺模型,product_id应传递店铺的标识符 + product_id: `store_${form.store_id}` } - const response = await axios.post('/api/predict', payload) + // 修正API端点 + const response = await axios.post('/api/prediction', payload) if (response.data.status === 'success') { - predictionResult.value = response.data.data + // 修正:数据现在直接在响应的顶层 + predictionResult.value = response.data ElMessage.success('预测完成!') await nextTick() renderChart() @@ -231,30 +235,58 @@ const renderChart = () => { if (chart) { chart.destroy() } - const predictions = predictionResult.value.predictions - const labels = predictions.map(p => p.date) - const data = predictions.map(p => p.sales) + + const historyData = predictionResult.value.history_data || [] + const predictionData = predictionResult.value.prediction_data || [] + + const labels = [ + ...historyData.map(p => p.date), + ...predictionData.map(p => p.date) + ] + + const historySales = historyData.map(p => p.sales) + // 预测数据需要填充与历史数据等长的null值,以保证图表正确对齐 + const predictionSales = [ + ...Array(historyData.length).fill(null), + ...predictionData.map(p => p.predicted_sales) + ] + chart = new Chart(chartCanvas.value, { type: 'line', data: { labels, - datasets: [{ - label: '预测销量', - data, - borderColor: '#409EFF', - backgroundColor: 'rgba(64, 158, 255, 0.1)', - tension: 0.4, - fill: true - }] + datasets: [ + { + label: '历史销量', + data: historySales, + borderColor: '#67C23A', + backgroundColor: 'rgba(103, 194, 58, 0.1)', + fill: false, + tension: 0.4 + }, + { + label: '预测销量', + data: predictionSales, + borderColor: '#409EFF', + backgroundColor: 'rgba(64, 158, 255, 0.1)', + borderDash: [5, 5], // 虚线 + fill: false, + tension: 0.4 + } + ] }, options: { responsive: true, plugins: { title: { display: true, - text: '销量预测趋势图' + text: '店铺销量历史与预测趋势图' } - } + }, + interaction: { + intersect: false, + mode: 'index', + }, } }) } diff --git a/lyf开发日志记录文档.md b/lyf开发日志记录文档.md index 541e7b3..636425a 100644 --- a/lyf开发日志记录文档.md +++ b/lyf开发日志记录文档.md @@ -121,3 +121,49 @@ ### 11:45 - 项目总结与文档归档 - **任务**: 根据用户要求,回顾整个调试过程,将所有问题、解决方案、优化思路和最终结论,按照日期和时间顺序,整理并更新到本开发日志中,形成一份高质量的技术档案。 - **结果**: 本文档已更新完成。 + + +### 13:15 - 最终修复:根治模型标识符不一致问题 +- **问题**: 经过再次测试和日志分析,发现即便是修正后,店铺模型的 `model_identifier` 在训练时依然被错误地构建为 `01010023_store_01010023`。 +- **根本原因**: `server/core/predictor.py` 的 `train_model` 方法中,在 `training_mode == 'store'` 的分支下,构建 `model_identifier` 的逻辑存在冗余和错误。 +- **最终解决方案**: 删除了错误的拼接逻辑 `model_identifier = f"{store_id}_{product_id}"`,直接使用在之前步骤中已经被正确赋值为 `f"store_{store_id}"` 的 `product_id` 变量作为 `model_identifier`。这确保了从训练、保存到最终API查询,店铺模型的唯一标识符始终保持一致。 + + +### 13:30 - 最终修复(第二轮):根治模型保存路径错误 +- **问题**: 即便修复了标识符,模型版本依然无法加载。 +- **根本原因**: 通过分析训练日志,发现所有训练器(`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`)中的 `save_checkpoint` 函数,都会强制在 `saved_models` 目录下创建一个 `checkpoints` 子目录,并将所有模型文件保存在其中。而负责查找模型的 `get_model_versions` 函数只在根目录查找,导致模型永远无法被发现。 +- **最终解决方案**: 逐一修改了所有相关训练器文件中的 `save_checkpoint` 函数,移除了创建和使用 `checkpoints` 子目录的逻辑,确保所有模型都直接保存在 `saved_models` 根目录下。 +- **结论**: 至此,模型保存的路径与查找的路径完全统一,从根本上解决了模型版本无法加载的问题。 + + +### 13:40 - 最终修复(第三轮):统一所有训练器的模型保存逻辑 +- **问题**: 在修复了 `transformer_trainer.py` 后,发现 `mlstm_trainer.py` 和 `tcn_trainer.py` 存在完全相同的路径和命名错误,导致问题依旧。 +- **根本原因**: `save_checkpoint` 函数在所有训练器中都被错误地实现,它们都强制创建了 `checkpoints` 子目录,并使用了错误的逻辑来拼接文件名。 +- **最终解决方案**: + 1. **逐一修复**: 逐一修改了 `transformer_trainer.py`, `mlstm_trainer.py`, 和 `tcn_trainer.py` 中的 `save_checkpoint` 函数。 + 2. **路径修复**: 移除了创建和使用 `checkpoints` 子目录的逻辑,确保模型直接保存在 `model_dir` (即 `saved_models`) 的根目录下。 + 3. **文件名修复**: 简化并修正了文件名的生成逻辑,直接使用 `product_id` 参数作为唯一标识符(该参数已由上游逻辑正确赋值为 `药品ID` 或 `store_{店铺ID}`),不再进行任何额外的、错误的拼接。 +- **结论**: 至此,所有训练器的模型保存逻辑完全统一,模型保存的路径和文件名与API的查找逻辑完全匹配,从根本上解决了模型版本无法加载的问题。 + + +--- + +## 2025-07-16 (续):端到端修复“店铺预测”图表功能 +**开发者**: lyf + +### 15:30 - 最终修复(第四轮):打通店铺预测的数据流 +- **问题**: 在解决了模型加载问题后,“店铺预测”功能虽然可以成功执行,但前端图表依然空白,不显示历史数据和预测数据。 +- **根本原因**: 参数传递在调用链中出现断裂。 + 1. `server/api.py` 在调用 `run_prediction` 时,没有传递 `training_mode`。 + 2. `server/core/predictor.py` 在调用 `load_model_and_predict` 时,没有传递 `store_id` 和 `training_mode`。 + 3. `server/predictors/model_predictor.py` 内部的数据加载逻辑,在处理店铺预测时,错误地使用了模型标识符(`store_{id}`)作为产品ID来过滤数据,导致无法加载到任何历史数据。 +- **最终解决方案 (三步修复)**: + 1. **修复 `model_predictor.py`**: 修改 `load_model_and_predict` 函数,使其能够根据 `training_mode` 参数智能地加载数据。当模式为 `'store'` 时,它会正确地聚合该店铺的所有销售数据作为历史数据,这与训练时的数据准备方式完全一致。 + 2. **修复 `predictor.py`**: 修改 `predict` 方法,将 `store_id` 和 `training_mode` 参数正确地传递给底层的 `load_model_and_predict` 函数。 + 3. **修复 `api.py`**: 修改 `predict` 路由和 `run_prediction` 辅助函数,确保 `training_mode` 参数在整个调用链中被完整传递。 +- **结论**: 通过以上修复,我们确保了从API接口到最底层数据加载器的参数传递是完整和正确的。现在,无论是药品预测还是店铺预测,系统都能够加载正确的历史数据用于图表绘制,彻底解决了图表显示空白的问题。 + +### 16:16 - 项目状态更新 +- **状态**: **所有已知问题已修复**。 +- **确认**: 用户已确认“现在药品和店铺预测流程通了”。 +- **后续**: 将本次修复过程归档至本文档。 diff --git a/prediction_history.db b/prediction_history.db index 18afacf..fc75b89 100644 Binary files a/prediction_history.db and b/prediction_history.db differ diff --git a/server/api.py b/server/api.py index 88d118f..1c3895f 100644 --- a/server/api.py +++ b/server/api.py @@ -1510,7 +1510,8 @@ def predict(): data = request.json product_id = data.get('product_id') model_type = data.get('model_type') - store_id = data.get('store_id') # 新增店铺ID参数 + store_id = data.get('store_id') + training_mode = 'store' if store_id else 'product' version = data.get('version') # 新增版本参数 future_days = int(data.get('future_days', 7)) start_date = data.get('start_date', '') @@ -1527,28 +1528,31 @@ def predict(): if not product_name: product_name = product_id - # 根据版本获取模型ID - if version: - # 如果指定了版本,构造版本化的模型ID - model_id = f"{product_id}_{model_type}_{version}" - # 检查指定版本的模型是否存在 - model_file_path = get_model_file_path(product_id, model_type, version) - if not os.path.exists(model_file_path): - return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型版本 {version}"}), 404 + # 根据训练模式构建模型标识符 + if training_mode == 'store': + model_identifier = f"store_{store_id}" + # 对于店铺预测,product_id实际上是store_id,但我们需要一个药品ID来获取名称,这里暂时用一个占位符 + product_name = f"店铺 {store_id} 整体" else: - # 如果没有指定版本,使用最新版本 - latest_version = get_latest_model_version(product_id, model_type) - if latest_version: - model_id = f"{product_id}_{model_type}_{latest_version}" - version = latest_version - else: - # 兼容旧的无版本模型 - model_id = get_latest_model_id(model_type, product_id) - if not model_id: - return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404 + model_identifier = product_id + product_name = get_product_name(product_id) or product_id + + # 获取模型版本 + if not version: + version = get_latest_model_version(model_identifier, model_type) + + if not version: + return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier} 的 {model_type} 类型模型"}), 404 + + # 检查模型文件是否存在 + model_file_path = get_model_file_path(model_identifier, model_type, version) + if not os.path.exists(model_file_path): + return jsonify({"status": "error", "error": f"未找到模型文件: {model_file_path}"}), 404 + + model_id = f"{model_identifier}_{model_type}_{version}" # 执行预测 - prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id) + prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode) if prediction_result is None: return jsonify({"status": "error", "error": "预测失败,预测器返回None"}), 500 @@ -2708,7 +2712,7 @@ def get_product_name(product_id): return None # 执行预测的辅助函数 -def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None): +def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product'): """执行模型预测""" try: scope_msg = f", store_id={store_id}" if store_id else ", 全局模型" @@ -2729,7 +2733,8 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve store_id=store_id, future_days=future_days, start_date=start_date, - version=version + version=version, + training_mode=training_mode ) if prediction_result is None: diff --git a/server/core/config.py b/server/core/config.py index dd67b52..fddd7b4 100644 --- a/server/core/config.py +++ b/server/core/config.py @@ -131,10 +131,11 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str: filename = f"{model_type}_model_product_{product_id}.pth" return os.path.join(DEFAULT_MODEL_DIR, filename) - # 处理新的、基于epoch的检查点命名格式 - # 文件名示例: transformer_product_17002608_epoch_best.pth - filename = f"{model_type}_product_{product_id}_epoch_{version}.pth" - return os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', filename) + # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名 + # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth + filename = f"{model_type}_{product_id}_epoch_{version}.pth" + # 修正:直接在根模型目录查找,不再使用checkpoints子目录 + return os.path.join(DEFAULT_MODEL_DIR, filename) def get_model_versions(product_id: str, model_type: str) -> list: """ @@ -149,9 +150,10 @@ def get_model_versions(product_id: str, model_type: str) -> list: """ # 直接使用传入的product_id构建搜索模式 # 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth" - pattern = f"{model_type}_product_{product_id}_epoch_*.pth" - # 在 checkpoints 子目录中查找 - search_path = os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', pattern) + # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式 + pattern = f"{model_type}_{product_id}_epoch_*.pth" + # 修正:直接在根模型目录查找,不再使用checkpoints子目录 + search_path = os.path.join(DEFAULT_MODEL_DIR, pattern) existing_files = glob.glob(search_path) # 旧格式(兼容性支持) diff --git a/server/core/predictor.py b/server/core/predictor.py index 2a22e92..9345b98 100644 --- a/server/core/predictor.py +++ b/server/core/predictor.py @@ -132,8 +132,8 @@ class PharmacyPredictor: file_path=self.data_path ) log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}") - # 将product_id设置为店铺ID,以便模型保存时使用有意义的标识 - product_id = store_id + # 将product_id设置为'store_{store_id}',与API查找逻辑保持一致 + product_id = f"store_{store_id}" except Exception as e: log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error') return None @@ -179,7 +179,7 @@ class PharmacyPredictor: # 根据训练模式构建模型标识符 if training_mode == 'store': - model_identifier = f"{store_id}_{product_id}" + model_identifier = product_id elif training_mode == 'global': model_identifier = f"global_{product_id}_{aggregation_method}" else: @@ -308,19 +308,22 @@ class PharmacyPredictor: """ # 根据训练模式构建模型标识符 if training_mode == 'store' and store_id: - model_identifier = f"{store_id}_{product_id}" + # 修正:店铺模型的标识符应该只基于店铺ID + model_identifier = f"store_{store_id}" elif training_mode == 'global': model_identifier = f"global_{product_id}_{aggregation_method}" else: model_identifier = product_id return load_model_and_predict( - model_identifier, - model_type, - future_days=future_days, - start_date=start_date, + model_identifier, + model_type, + store_id=store_id, + future_days=future_days, + start_date=start_date, analyze_result=analyze_result, - version=version + version=version, + training_mode=training_mode ) def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32, diff --git a/server/predictors/model_predictor.py b/server/predictors/model_predictor.py index 0051424..560d2ee 100644 --- a/server/predictors/model_predictor.py +++ b/server/predictors/model_predictor.py @@ -23,7 +23,7 @@ from utils.visualization import plot_prediction_results from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH -def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None): +def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'): """ 加载已训练的模型并进行预测 @@ -101,33 +101,37 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, # 加载销售数据(支持多店铺) try: - from utils.multi_store_data_utils import load_multi_store_data - if store_id: - # 加载特定店铺的数据 - product_df = load_multi_store_data( - file_path=DEFAULT_DATA_PATH, + from utils.multi_store_data_utils import aggregate_multi_store_data + + # 根据训练模式加载相应的数据 + if training_mode == 'store' and store_id: + # 店铺模型:聚合该店铺的所有产品数据 + product_df = aggregate_multi_store_data( store_id=store_id, - product_id=product_id + aggregation_method='sum', + file_path=DEFAULT_DATA_PATH ) - store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}" + store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}" prediction_scope = f"店铺 '{store_name}' ({store_id})" + # 对于店铺模型,其“产品名称”就是店铺名称 + product_name = store_name else: - # 聚合所有店铺的数据进行预测 + # 产品模型(默认):聚合该产品在所有店铺的数据 + # 此时,传入的product_id是真正的产品ID product_df = aggregate_multi_store_data( product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH ) prediction_scope = "全部店铺(聚合数据)" + product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id except Exception as e: print(f"加载数据失败: {e}") return None - + if product_df.empty: - print(f"产品 {product_id} 没有销售数据") + print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") return None - - product_name = product_df['product_name'].iloc[0] print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量") print(f"预测范围: {prediction_scope}") diff --git a/server/trainers/mlstm_trainer.py b/server/trainers/mlstm_trainer.py index 2f6eab5..c26f44e 100644 --- a/server/trainers/mlstm_trainer.py +++ b/server/trainers/mlstm_trainer.py @@ -42,16 +42,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, aggregation_method: 聚合方法 """ # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) - # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename) diff --git a/server/trainers/tcn_trainer.py b/server/trainers/tcn_trainer.py index 703ad68..acf5386 100644 --- a/server/trainers/tcn_trainer.py +++ b/server/trainers/tcn_trainer.py @@ -38,16 +38,13 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, aggregation_method: 聚合方法 """ # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename) diff --git a/server/trainers/transformer_trainer.py b/server/trainers/transformer_trainer.py index 1b7e41d..fb8a55f 100644 --- a/server/trainers/transformer_trainer.py +++ b/server/trainers/transformer_trainer.py @@ -43,17 +43,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str, training_mode: 训练模式 aggregation_method: 聚合方法 """ - # 创建检查点目录 - checkpoint_dir = os.path.join(model_dir, 'checkpoints') + # 直接在模型根目录保存,不再创建子目录 + checkpoint_dir = model_dir os.makedirs(checkpoint_dir, exist_ok=True) - # 生成检查点文件名 - if training_mode == 'store' and store_id: - filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" - elif training_mode == 'global' and aggregation_method: - filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth" - else: - filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth" + # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID + filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth" checkpoint_path = os.path.join(checkpoint_dir, filename)