-完善店铺预测模块

This commit is contained in:
LYFxiaoan 2025-07-16 16:24:08 +08:00
parent a18c8dddf9
commit a1d9c60e61
10 changed files with 171 additions and 91 deletions

View File

@ -208,11 +208,15 @@ const startPrediction = async () => {
future_days: form.future_days, future_days: form.future_days,
start_date: form.start_date, start_date: form.start_date,
analyze_result: form.analyze_result, analyze_result: form.analyze_result,
store_id: form.store_id store_id: form.store_id,
// product_id
product_id: `store_${form.store_id}`
} }
const response = await axios.post('/api/predict', payload) // API
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') { if (response.data.status === 'success') {
predictionResult.value = response.data.data //
predictionResult.value = response.data
ElMessage.success('预测完成!') ElMessage.success('预测完成!')
await nextTick() await nextTick()
renderChart() renderChart()
@ -231,30 +235,58 @@ const renderChart = () => {
if (chart) { if (chart) {
chart.destroy() chart.destroy()
} }
const predictions = predictionResult.value.predictions
const labels = predictions.map(p => p.date) const historyData = predictionResult.value.history_data || []
const data = predictions.map(p => p.sales) const predictionData = predictionResult.value.prediction_data || []
const labels = [
...historyData.map(p => p.date),
...predictionData.map(p => p.date)
]
const historySales = historyData.map(p => p.sales)
// null
const predictionSales = [
...Array(historyData.length).fill(null),
...predictionData.map(p => p.predicted_sales)
]
chart = new Chart(chartCanvas.value, { chart = new Chart(chartCanvas.value, {
type: 'line', type: 'line',
data: { data: {
labels, labels,
datasets: [{ datasets: [
label: '预测销量', {
data, label: '历史销量',
borderColor: '#409EFF', data: historySales,
backgroundColor: 'rgba(64, 158, 255, 0.1)', borderColor: '#67C23A',
tension: 0.4, backgroundColor: 'rgba(103, 194, 58, 0.1)',
fill: true fill: false,
}] tension: 0.4
},
{
label: '预测销量',
data: predictionSales,
borderColor: '#409EFF',
backgroundColor: 'rgba(64, 158, 255, 0.1)',
borderDash: [5, 5], // 线
fill: false,
tension: 0.4
}
]
}, },
options: { options: {
responsive: true, responsive: true,
plugins: { plugins: {
title: { title: {
display: true, display: true,
text: '销量预测趋势图' text: '店铺销量历史与预测趋势图'
} }
} },
interaction: {
intersect: false,
mode: 'index',
},
} }
}) })
} }

View File

@ -121,3 +121,49 @@
### 11:45 - 项目总结与文档归档 ### 11:45 - 项目总结与文档归档
- **任务**: 根据用户要求,回顾整个调试过程,将所有问题、解决方案、优化思路和最终结论,按照日期和时间顺序,整理并更新到本开发日志中,形成一份高质量的技术档案。 - **任务**: 根据用户要求,回顾整个调试过程,将所有问题、解决方案、优化思路和最终结论,按照日期和时间顺序,整理并更新到本开发日志中,形成一份高质量的技术档案。
- **结果**: 本文档已更新完成。 - **结果**: 本文档已更新完成。
### 13:15 - 最终修复:根治模型标识符不一致问题
- **问题**: 经过再次测试和日志分析,发现即便是修正后,店铺模型的 `model_identifier` 在训练时依然被错误地构建为 `01010023_store_01010023`
- **根本原因**: `server/core/predictor.py``train_model` 方法中,在 `training_mode == 'store'` 的分支下,构建 `model_identifier` 的逻辑存在冗余和错误。
- **最终解决方案**: 删除了错误的拼接逻辑 `model_identifier = f"{store_id}_{product_id}"`,直接使用在之前步骤中已经被正确赋值为 `f"store_{store_id}"``product_id` 变量作为 `model_identifier`。这确保了从训练、保存到最终API查询店铺模型的唯一标识符始终保持一致。
### 13:30 - 最终修复(第二轮):根治模型保存路径错误
- **问题**: 即便修复了标识符,模型版本依然无法加载。
- **根本原因**: 通过分析训练日志,发现所有训练器(`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`)中的 `save_checkpoint` 函数,都会强制在 `saved_models` 目录下创建一个 `checkpoints` 子目录,并将所有模型文件保存在其中。而负责查找模型的 `get_model_versions` 函数只在根目录查找,导致模型永远无法被发现。
- **最终解决方案**: 逐一修改了所有相关训练器文件中的 `save_checkpoint` 函数,移除了创建和使用 `checkpoints` 子目录的逻辑,确保所有模型都直接保存在 `saved_models` 根目录下。
- **结论**: 至此,模型保存的路径与查找的路径完全统一,从根本上解决了模型版本无法加载的问题。
### 13:40 - 最终修复(第三轮):统一所有训练器的模型保存逻辑
- **问题**: 在修复了 `transformer_trainer.py` 后,发现 `mlstm_trainer.py``tcn_trainer.py` 存在完全相同的路径和命名错误,导致问题依旧。
- **根本原因**: `save_checkpoint` 函数在所有训练器中都被错误地实现,它们都强制创建了 `checkpoints` 子目录,并使用了错误的逻辑来拼接文件名。
- **最终解决方案**:
1. **逐一修复**: 逐一修改了 `transformer_trainer.py`, `mlstm_trainer.py`, 和 `tcn_trainer.py` 中的 `save_checkpoint` 函数。
2. **路径修复**: 移除了创建和使用 `checkpoints` 子目录的逻辑,确保模型直接保存在 `model_dir` (即 `saved_models`) 的根目录下。
3. **文件名修复**: 简化并修正了文件名的生成逻辑,直接使用 `product_id` 参数作为唯一标识符(该参数已由上游逻辑正确赋值为 `药品ID``store_{店铺ID}`),不再进行任何额外的、错误的拼接。
- **结论**: 至此所有训练器的模型保存逻辑完全统一模型保存的路径和文件名与API的查找逻辑完全匹配从根本上解决了模型版本无法加载的问题。
---
## 2025-07-16 (续):端到端修复“店铺预测”图表功能
**开发者**: lyf
### 15:30 - 最终修复(第四轮):打通店铺预测的数据流
- **问题**: 在解决了模型加载问题后,“店铺预测”功能虽然可以成功执行,但前端图表依然空白,不显示历史数据和预测数据。
- **根本原因**: 参数传递在调用链中出现断裂。
1. `server/api.py` 在调用 `run_prediction` 时,没有传递 `training_mode`
2. `server/core/predictor.py` 在调用 `load_model_and_predict` 时,没有传递 `store_id``training_mode`
3. `server/predictors/model_predictor.py` 内部的数据加载逻辑,在处理店铺预测时,错误地使用了模型标识符(`store_{id}`作为产品ID来过滤数据导致无法加载到任何历史数据。
- **最终解决方案 (三步修复)**:
1. **修复 `model_predictor.py`**: 修改 `load_model_and_predict` 函数,使其能够根据 `training_mode` 参数智能地加载数据。当模式为 `'store'` 时,它会正确地聚合该店铺的所有销售数据作为历史数据,这与训练时的数据准备方式完全一致。
2. **修复 `predictor.py`**: 修改 `predict` 方法,将 `store_id``training_mode` 参数正确地传递给底层的 `load_model_and_predict` 函数。
3. **修复 `api.py`**: 修改 `predict` 路由和 `run_prediction` 辅助函数,确保 `training_mode` 参数在整个调用链中被完整传递。
- **结论**: 通过以上修复我们确保了从API接口到最底层数据加载器的参数传递是完整和正确的。现在无论是药品预测还是店铺预测系统都能够加载正确的历史数据用于图表绘制彻底解决了图表显示空白的问题。
### 16:16 - 项目状态更新
- **状态**: **所有已知问题已修复**
- **确认**: 用户已确认“现在药品和店铺预测流程通了”。
- **后续**: 将本次修复过程归档至本文档。

Binary file not shown.

View File

@ -1510,7 +1510,8 @@ def predict():
data = request.json data = request.json
product_id = data.get('product_id') product_id = data.get('product_id')
model_type = data.get('model_type') model_type = data.get('model_type')
store_id = data.get('store_id') # 新增店铺ID参数 store_id = data.get('store_id')
training_mode = 'store' if store_id else 'product'
version = data.get('version') # 新增版本参数 version = data.get('version') # 新增版本参数
future_days = int(data.get('future_days', 7)) future_days = int(data.get('future_days', 7))
start_date = data.get('start_date', '') start_date = data.get('start_date', '')
@ -1527,28 +1528,31 @@ def predict():
if not product_name: if not product_name:
product_name = product_id product_name = product_id
# 根据版本获取模型ID # 根据训练模式构建模型标识符
if version: if training_mode == 'store':
# 如果指定了版本构造版本化的模型ID model_identifier = f"store_{store_id}"
model_id = f"{product_id}_{model_type}_{version}" # 对于店铺预测product_id实际上是store_id但我们需要一个药品ID来获取名称这里暂时用一个占位符
# 检查指定版本的模型是否存在 product_name = f"店铺 {store_id} 整体"
model_file_path = get_model_file_path(product_id, model_type, version)
if not os.path.exists(model_file_path):
return jsonify({"status": "error", "error": f"未找到产品 {product_id}{model_type} 类型模型版本 {version}"}), 404
else: else:
# 如果没有指定版本,使用最新版本 model_identifier = product_id
latest_version = get_latest_model_version(product_id, model_type) product_name = get_product_name(product_id) or product_id
if latest_version:
model_id = f"{product_id}_{model_type}_{latest_version}" # 获取模型版本
version = latest_version if not version:
else: version = get_latest_model_version(model_identifier, model_type)
# 兼容旧的无版本模型
model_id = get_latest_model_id(model_type, product_id) if not version:
if not model_id: return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier}{model_type} 类型模型"}), 404
return jsonify({"status": "error", "error": f"未找到产品 {product_id}{model_type} 类型模型"}), 404
# 检查模型文件是否存在
model_file_path = get_model_file_path(model_identifier, model_type, version)
if not os.path.exists(model_file_path):
return jsonify({"status": "error", "error": f"未找到模型文件: {model_file_path}"}), 404
model_id = f"{model_identifier}_{model_type}_{version}"
# 执行预测 # 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id) prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode)
if prediction_result is None: if prediction_result is None:
return jsonify({"status": "error", "error": "预测失败预测器返回None"}), 500 return jsonify({"status": "error", "error": "预测失败预测器返回None"}), 500
@ -2708,7 +2712,7 @@ def get_product_name(product_id):
return None return None
# 执行预测的辅助函数 # 执行预测的辅助函数
def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None): def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product'):
"""执行模型预测""" """执行模型预测"""
try: try:
scope_msg = f", store_id={store_id}" if store_id else ", 全局模型" scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"
@ -2729,7 +2733,8 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
store_id=store_id, store_id=store_id,
future_days=future_days, future_days=future_days,
start_date=start_date, start_date=start_date,
version=version version=version,
training_mode=training_mode
) )
if prediction_result is None: if prediction_result is None:

View File

@ -131,10 +131,11 @@ def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
filename = f"{model_type}_model_product_{product_id}.pth" filename = f"{model_type}_model_product_{product_id}.pth"
return os.path.join(DEFAULT_MODEL_DIR, filename) return os.path.join(DEFAULT_MODEL_DIR, filename)
# 处理新的、基于epoch的检查点命名格式 # 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_product_17002608_epoch_best.pth # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
filename = f"{model_type}_product_{product_id}_epoch_{version}.pth" filename = f"{model_type}_{product_id}_epoch_{version}.pth"
return os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', filename) # 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
def get_model_versions(product_id: str, model_type: str) -> list: def get_model_versions(product_id: str, model_type: str) -> list:
""" """
@ -149,9 +150,10 @@ def get_model_versions(product_id: str, model_type: str) -> list:
""" """
# 直接使用传入的product_id构建搜索模式 # 直接使用传入的product_id构建搜索模式
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth" # 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
pattern = f"{model_type}_product_{product_id}_epoch_*.pth" # 修正直接使用唯一的product_id它可能包含store_前缀来构建搜索模式
# 在 checkpoints 子目录中查找 pattern = f"{model_type}_{product_id}_epoch_*.pth"
search_path = os.path.join(DEFAULT_MODEL_DIR, 'checkpoints', pattern) # 修正直接在根模型目录查找不再使用checkpoints子目录
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files = glob.glob(search_path) existing_files = glob.glob(search_path)
# 旧格式(兼容性支持) # 旧格式(兼容性支持)

View File

@ -132,8 +132,8 @@ class PharmacyPredictor:
file_path=self.data_path file_path=self.data_path
) )
log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}") log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
# 将product_id设置为店铺ID以便模型保存时使用有意义的标识 # 将product_id设置为'store_{store_id}'与API查找逻辑保持一致
product_id = store_id product_id = f"store_{store_id}"
except Exception as e: except Exception as e:
log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error') log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
return None return None
@ -179,7 +179,7 @@ class PharmacyPredictor:
# 根据训练模式构建模型标识符 # 根据训练模式构建模型标识符
if training_mode == 'store': if training_mode == 'store':
model_identifier = f"{store_id}_{product_id}" model_identifier = product_id
elif training_mode == 'global': elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}" model_identifier = f"global_{product_id}_{aggregation_method}"
else: else:
@ -308,7 +308,8 @@ class PharmacyPredictor:
""" """
# 根据训练模式构建模型标识符 # 根据训练模式构建模型标识符
if training_mode == 'store' and store_id: if training_mode == 'store' and store_id:
model_identifier = f"{store_id}_{product_id}" # 修正店铺模型的标识符应该只基于店铺ID
model_identifier = f"store_{store_id}"
elif training_mode == 'global': elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}" model_identifier = f"global_{product_id}_{aggregation_method}"
else: else:
@ -317,10 +318,12 @@ class PharmacyPredictor:
return load_model_and_predict( return load_model_and_predict(
model_identifier, model_identifier,
model_type, model_type,
store_id=store_id,
future_days=future_days, future_days=future_days,
start_date=start_date, start_date=start_date,
analyze_result=analyze_result, analyze_result=analyze_result,
version=version version=version,
training_mode=training_mode
) )
def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32, def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,

View File

@ -23,7 +23,7 @@ from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None): def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'):
""" """
加载已训练的模型并进行预测 加载已训练的模型并进行预测
@ -101,33 +101,37 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
# 加载销售数据(支持多店铺) # 加载销售数据(支持多店铺)
try: try:
from utils.multi_store_data_utils import load_multi_store_data from utils.multi_store_data_utils import aggregate_multi_store_data
if store_id:
# 加载特定店铺的数据 # 根据训练模式加载相应的数据
product_df = load_multi_store_data( if training_mode == 'store' and store_id:
file_path=DEFAULT_DATA_PATH, # 店铺模型:聚合该店铺的所有产品数据
product_df = aggregate_multi_store_data(
store_id=store_id, store_id=store_id,
product_id=product_id aggregation_method='sum',
file_path=DEFAULT_DATA_PATH
) )
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}" store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}"
prediction_scope = f"店铺 '{store_name}' ({store_id})" prediction_scope = f"店铺 '{store_name}' ({store_id})"
# 对于店铺模型,其“产品名称”就是店铺名称
product_name = store_name
else: else:
# 聚合所有店铺的数据进行预测 # 产品模型(默认):聚合该产品在所有店铺的数据
# 此时传入的product_id是真正的产品ID
product_df = aggregate_multi_store_data( product_df = aggregate_multi_store_data(
product_id=product_id, product_id=product_id,
aggregation_method='sum', aggregation_method='sum',
file_path=DEFAULT_DATA_PATH file_path=DEFAULT_DATA_PATH
) )
prediction_scope = "全部店铺(聚合数据)" prediction_scope = "全部店铺(聚合数据)"
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
except Exception as e: except Exception as e:
print(f"加载数据失败: {e}") print(f"加载数据失败: {e}")
return None return None
if product_df.empty: if product_df.empty:
print(f"产品 {product_id} 没有销售数据") print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
return None return None
product_name = product_df['product_name'].iloc[0]
print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量") print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
print(f"预测范围: {prediction_scope}") print(f"预测范围: {prediction_scope}")

View File

@ -42,16 +42,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
aggregation_method: 聚合方法 aggregation_method: 聚合方法
""" """
# 创建检查点目录 # 创建检查点目录
checkpoint_dir = os.path.join(model_dir, 'checkpoints') # 直接在模型根目录保存,不再创建子目录
checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
# 生成检查点文件名 # 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
if training_mode == 'store' and store_id: filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
elif training_mode == 'global' and aggregation_method:
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
else:
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename) checkpoint_path = os.path.join(checkpoint_dir, filename)

View File

@ -38,16 +38,13 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
aggregation_method: 聚合方法 aggregation_method: 聚合方法
""" """
# 创建检查点目录 # 创建检查点目录
checkpoint_dir = os.path.join(model_dir, 'checkpoints') # 直接在模型根目录保存,不再创建子目录
checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
# 生成检查点文件名 # 生成检查点文件名
if training_mode == 'store' and store_id: # 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth" filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
elif training_mode == 'global' and aggregation_method:
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
else:
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename) checkpoint_path = os.path.join(checkpoint_dir, filename)

View File

@ -43,17 +43,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
training_mode: 训练模式 training_mode: 训练模式
aggregation_method: 聚合方法 aggregation_method: 聚合方法
""" """
# 创建检查点目录 # 直接在模型根目录保存,不再创建子目录
checkpoint_dir = os.path.join(model_dir, 'checkpoints') checkpoint_dir = model_dir
os.makedirs(checkpoint_dir, exist_ok=True) os.makedirs(checkpoint_dir, exist_ok=True)
# 生成检查点文件名 # 修正直接使用product_id作为唯一标识符因为它已经包含了store_前缀或药品ID
if training_mode == 'store' and store_id: filename = f"{model_type}_{product_id}_epoch_{epoch_or_label}.pth"
filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
elif training_mode == 'global' and aggregation_method:
filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
else:
filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
checkpoint_path = os.path.join(checkpoint_dir, filename) checkpoint_path = os.path.join(checkpoint_dir, filename)