Compare commits
4 commits: cc30295f1d ... ee9ba299fa

Commits (SHA1): ee9ba299fa, a1d9c60e61, a18c8dddf9, 398e949935
@@ -185,13 +185,14 @@ const startPrediction = async () => {
   try {
     predicting.value = true
     const payload = {
+      training_mode: 'global', // 明确指定训练模式
      model_type: form.model_type,
      version: form.version,
      future_days: form.future_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result
    }
-    const response = await axios.post('/api/predict', payload)
+    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
      predictionResult.value = response.data.data
      ElMessage.success('预测完成!')

@@ -203,24 +203,27 @@ const startPrediction = async () => {
   try {
     predicting.value = true
     const payload = {
+      product_id: form.product_id,
      model_type: form.model_type,
      version: form.version,
      future_days: form.future_days,
      start_date: form.start_date,
      analyze_result: form.analyze_result,
-      product_id: form.product_id
+      include_visualization: form.analyze_result,
    }
-    const response = await axios.post('/api/predict', payload)
+    // Corrected API endpoint from /api/predict to /api/prediction
+    const response = await axios.post('/api/prediction', payload)
    if (response.data.status === 'success') {
-      predictionResult.value = response.data.data
+      // The backend response may have history_data and prediction_data at the top level
+      predictionResult.value = response.data
      ElMessage.success('预测完成!')
+      await nextTick()
+      renderChart()
    } else {
-      ElMessage.error(response.data.message || '预测失败')
+      ElMessage.error(response.data.error || '预测失败')
    }
-  } catch (error) {
-    ElMessage.error('预测请求失败')
+  } catch (error)
+  {
+    ElMessage.error(error.response?.data?.error || '预测请求失败')
  } finally {
    predicting.value = false
  }

@@ -231,28 +234,87 @@ const renderChart = () => {
   if (chart) {
     chart.destroy()
   }
-  const predictions = predictionResult.value.predictions
-  const labels = predictions.map(p => p.date)
-  const data = predictions.map(p => p.sales)
+  // Backend provides history_data and prediction_data
+  const historyData = predictionResult.value.history_data || []
+  const predictionData = predictionResult.value.prediction_data || []
+
+  if (historyData.length === 0 && predictionData.length === 0) {
+    ElMessage.warning('没有可用于图表的数据。')
+    return
+  }
+
+  const historyLabels = historyData.map(p => p.date)
+  const historySales = historyData.map(p => p.sales)
+
+  const predictionLabels = predictionData.map(p => p.date)
+  const predictionSales = predictionData.map(p => p.predicted_sales)
+
+  // Combine labels and remove duplicates, then sort
+  const allLabels = [...new Set([...historyLabels, ...predictionLabels])].sort()
+
+  // Create a mapping of label to sales data for easier lookup
+  const historyMap = new Map(historyData.map(p => [p.date, p.sales]))
+  const predictionMap = new Map(predictionData.map(p => [p.date, p.predicted_sales]))
+
+  // Align data with the sorted labels
+  const alignedHistorySales = allLabels.map(label => historyMap.get(label) ?? null)
+  const alignedPredictionSales = allLabels.map(label => predictionMap.get(label) ?? null)
+
+  // The last point of history should connect to the first point of prediction for a smooth graph
+  if (historyData.length > 0 && predictionData.length > 0) {
+    const lastHistoryDate = historyLabels[historyLabels.length - 1]
+    const lastHistoryValue = historySales[historySales.length - 1]
+    if (!predictionMap.has(lastHistoryDate)) {
+      alignedPredictionSales[allLabels.indexOf(lastHistoryDate)] = lastHistoryValue
+    }
+  }
+
  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
-      labels,
-      datasets: [{
+      labels: allLabels,
+      datasets: [
+        {
+          label: '历史销量',
+          data: alignedHistorySales,
+          borderColor: '#67C23A',
+          backgroundColor: 'rgba(103, 194, 58, 0.1)',
+          tension: 0.1,
+          spanGaps: false, // Do not draw line over nulls
+        },
+        {
          label: '预测销量',
-          data,
+          data: alignedPredictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.1)',
-          tension: 0.4,
-          fill: true
-        }]
+          tension: 0.1,
+          fill: true,
+          borderDash: [5, 5], // Dashed line for predictions
+        }
+      ]
    },
    options: {
      responsive: true,
      plugins: {
        title: {
          display: true,
-          text: '销量预测趋势图'
+          text: `“${form.product_id}” - 销量预测趋势图`
        }
      },
      scales: {
        x: {
          title: {
            display: true,
            text: '日期'
          }
        },
        y: {
          title: {
            display: true,
            text: '销量'
          },
          beginAtZero: true
        }
      }
    }

@@ -208,11 +208,15 @@ const startPrediction = async () => {
     future_days: form.future_days,
     start_date: form.start_date,
     analyze_result: form.analyze_result,
-    store_id: form.store_id
+    store_id: form.store_id,
+    // 修正:对于店铺模型,product_id应传递店铺的标识符
+    product_id: `store_${form.store_id}`
   }
-  const response = await axios.post('/api/predict', payload)
+  // 修正API端点
+  const response = await axios.post('/api/prediction', payload)
  if (response.data.status === 'success') {
-    predictionResult.value = response.data.data
+    // 修正:数据现在直接在响应的顶层
+    predictionResult.value = response.data
    ElMessage.success('预测完成!')
    await nextTick()
    renderChart()

@@ -231,30 +235,58 @@ const renderChart = () => {
   if (chart) {
     chart.destroy()
   }
-  const predictions = predictionResult.value.predictions
-  const labels = predictions.map(p => p.date)
-  const data = predictions.map(p => p.sales)
+  const historyData = predictionResult.value.history_data || []
+  const predictionData = predictionResult.value.prediction_data || []
+
+  const labels = [
+    ...historyData.map(p => p.date),
+    ...predictionData.map(p => p.date)
+  ]
+
+  const historySales = historyData.map(p => p.sales)
+  // 预测数据需要填充与历史数据等长的null值,以保证图表正确对齐
+  const predictionSales = [
+    ...Array(historyData.length).fill(null),
+    ...predictionData.map(p => p.predicted_sales)
+  ]
+
  chart = new Chart(chartCanvas.value, {
    type: 'line',
    data: {
      labels,
-      datasets: [{
+      datasets: [
+        {
+          label: '历史销量',
+          data: historySales,
+          borderColor: '#67C23A',
+          backgroundColor: 'rgba(103, 194, 58, 0.1)',
+          fill: false,
+          tension: 0.4
+        },
+        {
          label: '预测销量',
-          data,
+          data: predictionSales,
          borderColor: '#409EFF',
          backgroundColor: 'rgba(64, 158, 255, 0.1)',
-          tension: 0.4,
-          fill: true
-        }]
+          borderDash: [5, 5], // 虚线
+          fill: false,
+          tension: 0.4
+        }
+      ]
    },
    options: {
      responsive: true,
      plugins: {
        title: {
          display: true,
-          text: '销量预测趋势图'
-        }
+          text: '店铺销量历史与预测趋势图'
+        }
      },
+      interaction: {
+        intersect: false,
+        mode: 'index',
+      },
    }
  })
}
lyf开发日志记录文档.md — 200 lines (new file)
@@ -0,0 +1,200 @@
# Development Log

This document records the major changes, bug fixes, and key decisions made during development of the project.

---

## 2025-07-13: Early Backend Fixes and Refactoring
**Developer**: lyf

### 13:30 - Fixed the data-loading path
- **Goal**: Resolve data-loading failures during model training caused by a wrong data file path.
- **Root cause**: The `PharmacyPredictor` class in `server/core/predictor.py` hard-coded an incorrect default data file path at initialization.
- **Fix**: Corrected the default data path to `'data/timeseries_training_data_sample_10s50p.parquet'` and updated all trainers to match.

### 14:00 - Data-flow refactoring
- **Goal**: Fix the root cause of model-training failures: key features were being lost because the data-processing pipeline was broken in the middle.
- **Root cause**: `predictor.py` did not pass the preprocessed data downstream, so every trainer re-loaded and mis-processed the data on its own.
- **Fix**: Refactored the core data flow so that data is loaded and preprocessed once in `predictor.py` and then passed explicitly, as a DataFrame, to all downstream trainer functions.

---

## 2025-07-14: Concentrated Work on Training and Concurrency Issues
**Developer**: lyf

### 10:16 - Fixed a `KeyError` in the trainer layer
- **Problem**: All model training failed with `KeyError: "['sales', 'price'] not in index"`.
- **Analysis**: The hard-coded feature list in the trainers included a `'price'` column that does not exist in the data source.
- **Fix**: Removed the dependency on the nonexistent `'price'` column from the `features` list of all four trainers (`mlstm`, `transformer`, `tcn`, `kan`).

### 10:38 - Fixed a `KeyError` in the data-standardization layer
- **Problem**: After that fix, a new error appeared: `KeyError: "['sales'] not in index"`.
- **Analysis**: The `standardize_column_names` function in `server/utils/multi_store_data_utils.py` mapped column names incorrectly and had no final column-selection step.
- **Fix**: Corrected the column-name mapping and added a column-selection step, so the function returns a uniformly structured `DataFrame` that always contains the `sales` column.

### 11:04 - Fixed JSON serialization failures
- **Problem**: After training completed, frontend-backend communication failed with `Object of type float32 is not JSON serializable`.
- **Analysis**: The evaluation metrics produced during training are NumPy `float32` values, which the standard `json` library cannot serialize.
- **Fix**: Added a `convert_numpy_types` helper to `server/utils/training_process_manager.py` that converts all NumPy numeric types to native Python types before data is returned over WebSocket or the API, fixing every serialization problem at the source.
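A minimal sketch of such a recursive helper, assuming it only needs to handle dicts, lists, and NumPy scalars/arrays (the actual implementation in `training_process_manager.py` may differ):

```python
import numpy as np

def convert_numpy_types(obj):
    """Recursively convert NumPy scalars and arrays to native Python types."""
    if isinstance(obj, dict):
        return {k: convert_numpy_types(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [convert_numpy_types(v) for v in obj]
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    return obj
```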
### 11:15 - Fixed the MAPE computation
- **Problem**: Training logs showed `MAPE: nan%` along with `RuntimeWarning: Mean of empty slice.`.
- **Analysis**: When every ground-truth value in the test set is zero, the MAPE computation averages over an empty array.
- **Fix**: Added a guard in `server/analysis/metrics.py`: if there are no nonzero ground-truth values, MAPE is set directly to 0.
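A sketch of the guarded computation, assuming NumPy arrays (the function name and exact convention in `metrics.py` may differ):

```python
import numpy as np

def safe_mape(y_true, y_pred):
    """MAPE (in percent) that returns 0 instead of NaN when all true values are zero."""
    y_true = np.asarray(y_true, dtype=float)
    y_pred = np.asarray(y_pred, dtype=float)
    nonzero = y_true != 0
    if not nonzero.any():
        # No nonzero ground truth: define MAPE as 0 rather than the mean of an empty slice.
        return 0.0
    return float(np.mean(np.abs((y_true[nonzero] - y_pred[nonzero]) / y_true[nonzero])) * 100)
```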
### 11:41 - Fixed the empty list on the "Train by Store" page
- **Problem**: The "select store" dropdown was empty.
- **Analysis**: `standardize_column_names` wrongly dropped columns that are not required for training, including store metadata.
- **Fix**: Moved the column-filtering logic out of the generic `standardize_column_names` function and applied it only in the function that prepares data specifically for model training.

### 13:00 - Fixed the "Train by Store - All Products" mode
- **Problem**: Training with "all products" selected failed because `product_id` was wrongly handled as the string `"unknown"`.
- **Fix**: Intercepted the `"unknown"` ID in `server/core/predictor.py` and translated its intent correctly into "aggregate all product data for this store". Also extended `aggregate_multi_store_data` to support aggregation by store ID.

### 14:19 - Fixed stability issues in concurrent training
- **Problem**: Concurrent training triggered an API list-sorting error and WebSocket connection errors.
- **Fixes**:
  1. **Sorting**: Supplied a default for `None`-typed `start_time` values in `api.py`, resolving the `TypeError`.
  2. **Connection**: Added `allow_unsafe_werkzeug=True` to the `socketio.run()` call, resolving the conflict between Socket.IO and Werkzeug in debug mode.

### 15:30 - Root-caused the dimension mismatch in model training
- **Problem**: After every training run, the evaluation metric `R²` was always 0.0.
- **Root cause**: When building the target dataset `dataY`, the `create_dataset` function in `server/utils/data_utils.py` kept a redundant extra dimension. The model files (`mlstm_model.py`, `transformer_model.py`) had matching output-dimension problems.
- **Final fix** (see the sketch after this list):
  1. **Data layer**: Used `.flatten()` in `create_dataset` to correct the dimension of the `y` labels.
  2. **Model layer**: Appended `.squeeze(-1)` at the end of every model's `forward` method so the output dimension is correct.
  3. **Trainer layer**: Reverted all temporary dimension hacks made for this issue and restored the most direct loss computation.
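A sketch of the shape contract after the fix. The signature matches how the trainers call `create_dataset` in the diffs below; the body is illustrative, not the project's exact code:

```python
import numpy as np

def create_dataset(X, y, look_back, horizon):
    """Build (samples, look_back, features) inputs and (samples, horizon) targets.

    The earlier bug kept a trailing dimension on dataY; flatten() removes it.
    """
    dataX, dataY = [], []
    for i in range(len(X) - look_back - horizon + 1):
        dataX.append(X[i:i + look_back])
        dataY.append(y[i + look_back:i + look_back + horizon].flatten())  # (horizon,) not (horizon, 1)
    return np.array(dataX), np.array(dataY)

# The matching fix on the model side is to squeeze the trailing output dim,
# e.g. at the end of forward(): return out.squeeze(-1)  # (batch, horizon, 1) -> (batch, horizon)
```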
### 16:10 - Fixed the "Global Training - All Products" mode
- **Problem**: As with store training, global training in "all products" mode failed because of `product_id="unknown"`.
- **Fix**: Applied exactly the same pattern as for store training: intercept `"unknown"` in `predictor.py` and convert it into a true global aggregation (`product_id=None`), extending `aggregate_multi_store_data` to support it.

---

## 2025-07-15: End-to-End Fix of the "Predict by Product" Chart
**Developer**: lyf

### 10:00 - Stage 1: Fixed the database write failure (`sqlite3.IntegrityError`)
- **Problem**: Backend logs showed `datatype mismatch`.
- **Analysis**: The `save_prediction_result` function tried to write complex Python objects directly into the database.
- **Fix**: In `server/api.py`, serialize complex objects to JSON strings with `json.dumps()` before the database insert.
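A minimal sketch of the pattern, with a simplified, hypothetical table layout (the real `prediction_history` schema appears in the `server/api.py` diff further down):

```python
import json
import sqlite3

def save_prediction_result(conn: sqlite3.Connection, prediction_id: str, result: dict) -> None:
    """Store complex sub-objects as JSON strings; sqlite rejects dicts/lists as values."""
    cursor = conn.cursor()
    cursor.execute(
        "INSERT INTO prediction_history (id, predictions_data, analysis) VALUES (?, ?, ?)",
        (
            prediction_id,
            json.dumps(result.get("prediction_data", [])),  # serialize, don't pass the list itself
            json.dumps(result.get("analysis", {})),
        ),
    )
    conn.commit()
```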
### 10:30 - Stage 2: Fixed the mismatch between the API response structure and the frontend
- **Problem**: The chart still did not render.
- **Analysis**: The frontend expects `history_data` at the top level, while the backend wrapped it inside a `data` sub-object.
- **Fix**: Modified the `predict` function in `server/api.py` to lift the key data to the root of the response.

### 11:00 - Stage 3: Fixed the temporal gap between history and prediction data
- **Problem**: The chart's two series were completely disconnected in time.
- **Analysis**: The history-fetching logic always took the last 30 rows of the whole dataset instead of the 30 rows immediately before the prediction start date.
- **Fix**: Added the correct date-filtering logic in `server/api.py`.
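The idea in a few lines of pandas; function name and exact column handling are illustrative:

```python
import pandas as pd

def history_before(df: pd.DataFrame, start_date: str, days: int = 30) -> pd.DataFrame:
    """Take the `days` rows immediately before start_date, not simply the dataset's last rows."""
    out = df.copy()
    out["date"] = pd.to_datetime(out["date"])
    out = out.sort_values("date")
    return out[out["date"] < pd.to_datetime(start_date)].tail(days)
```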
### 14:00 - Stage 4: Refactored the data source to eliminate inconsistency
- **Problem**: The history series (green line) and the prediction series (blue line) were built from completely different data.
- **Root cause**: The API layer independently loaded the **raw data** for plotting, while the predictor used the **aggregated data** for prediction.
- **Fix (refactor)**:
  1. Modified `server/predictors/model_predictor.py` to return, alongside the predictions, the exact history data it used, so both series share one basis.
  2. Deleted all redundant history-loading code from `server/api.py`, making the data source unique.

### 15:00 - Stage 5: Fixed the chart's X-axis date format
- **Problem**: The X axis showed garbled GMT-style timestamps.
- **Analysis**: The `Timestamp` objects in `history_data` were never formatted.
- **Fix**: Added `.strftime('%Y-%m-%d')` formatting for `history_data` in `server/api.py`.

### 16:00 - Stage 6: Fixed the root cause of models "not learning" (broken hyperparameter passing)
- **Problem**: Even with the pipeline correct, every model still predicted an unlearned flat line.
- **Root cause**: When invoking the trainers, `server/core/predictor.py` **did not pass down key hyperparameters such as `sequence_length`**, so every model trained with wrong defaults.
- **Fix** (sketched below):
  1. Modified `server/core/predictor.py` to pass the hyperparameters through in its calls.
  2. Modified all four trainer files to accept and use these parameters.
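Schematically, the fix is just forwarding the parameters instead of letting the trainer fall back to its module-level defaults. The trainer signature below matches the diffs later in this compare; the wrapper and import path are assumptions for illustration:

```python
from trainers.mlstm_trainer import train_product_model_with_mlstm  # assumed import path

def train_model(product_id, product_df, epochs=50, sequence_length=30, forecast_horizon=7):
    # Before the fix, this call dropped sequence_length/forecast_horizon,
    # so the trainer silently trained with LOOK_BACK/FORECAST_HORIZON defaults.
    return train_product_model_with_mlstm(
        product_id=product_id,
        model_identifier=product_id,
        product_df=product_df,
        epochs=epochs,
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon,
    )
```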
---

## 2025-07-16: Final Verification and Project Wrap-Up
**Developer**: lyf

### 10:00 - Stage 7: Final verification and conclusion
- **Problem**: With every code issue fixed, predictions for one particular date were still a flat line.
- **Analysis**: A throwaway analysis script (`temp_check_parquet.py`) finally confirmed this is a property of the **data itself**: the chosen prediction date falls inside a "zero sales" gap in the sample dataset.
- **Conclusion**: The system code is fully fixed. The flat line on the chart is the model's **correct and logical** response to a zero-sales history.

### 11:45 - Project summary and archiving
- **Task**: At the user's request, review the whole debugging process and consolidate every problem, solution, optimization idea, and conclusion into this log, in chronological order, as a high-quality technical record.
- **Result**: This document is up to date.

### 13:15 - Final fix: root-caused the inconsistent model identifier
- **Problem**: Renewed testing and log analysis showed that, even after the correction, the store model's `model_identifier` was still built incorrectly as `01010023_store_01010023` during training.
- **Root cause**: In the `train_model` method of `server/core/predictor.py`, the `training_mode == 'store'` branch contained redundant, incorrect logic for building `model_identifier`.
- **Final solution**: Removed the faulty concatenation `model_identifier = f"{store_id}_{product_id}"` and used the `product_id` variable directly as the `model_identifier`, since an earlier step already sets it to `f"store_{store_id}"`. The store model's unique identifier is now consistent from training and saving through to the final API lookup.

### 13:30 - Final fix (round 2): root-caused the wrong model save path
- **Problem**: Even with the identifier fixed, model versions still failed to load.
- **Root cause**: Training logs revealed that `save_checkpoint` in every trainer (`transformer_trainer.py`, `mlstm_trainer.py`, `tcn_trainer.py`) forcibly created a `checkpoints` subdirectory under `saved_models` and saved all model files there, while the `get_model_versions` lookup searched only the root directory, so models could never be found.
- **Final solution**: Updated `save_checkpoint` in every affected trainer to drop the `checkpoints` subdirectory and save all models directly under the `saved_models` root.
- **Conclusion**: The save path and the lookup path are now unified, eliminating the version-loading failure at its root.

### 13:40 - Final fix (round 3): unified the model-saving logic across trainers
- **Problem**: After `transformer_trainer.py` was fixed, the identical path and naming errors surfaced in `mlstm_trainer.py` and `tcn_trainer.py`, so the symptom persisted.
- **Root cause**: `save_checkpoint` was implemented wrongly in every trainer: each forced a `checkpoints` subdirectory and assembled the file name with faulty logic.
- **Final solution**:
  1. **One-by-one fixes**: Updated `save_checkpoint` in `transformer_trainer.py`, `mlstm_trainer.py`, and `tcn_trainer.py`.
  2. **Path fix**: Removed the `checkpoints` subdirectory logic so models save directly in the root of `model_dir` (i.e. `saved_models`).
  3. **Filename fix**: Simplified the filename logic to use the `product_id` parameter directly as the unique identifier (upstream logic already sets it to either a product ID or `store_{store_id}`), with no further erroneous concatenation.
- **Conclusion**: All trainers now share one saving convention, and save paths and file names match the API's lookup logic exactly, eliminating the version-loading failure for good.
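The unified naming contract, sketched as one helper shared conceptually by saving and lookup. The filename pattern and the example are taken from the diffs in this compare; the helper itself is illustrative:

```python
import os

DEFAULT_MODEL_DIR = "saved_models"

def checkpoint_path(model_type: str, model_identifier: str, epoch_or_label) -> str:
    """Single naming rule shared by save_checkpoint and the version lookup.

    model_identifier is either a product ID ('17002608') or 'store_<id>';
    epoch_or_label is an epoch number, 'best', or 'final_epoch_<n>'.
    """
    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"
    return os.path.join(DEFAULT_MODEL_DIR, filename)  # root dir, no 'checkpoints' subdir

# e.g. checkpoint_path('transformer', 'store_01010023', 'best')
# -> 'saved_models/transformer_store_01010023_epoch_best.pth'
```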
---

## 2025-07-16 (continued): End-to-End Fix of the "Predict by Store" Chart
**Developer**: lyf

### 15:30 - Final fix (round 4): connected the store-prediction data flow
- **Problem**: With model loading solved, "predict by store" ran successfully, but the frontend chart stayed blank: neither history nor predictions appeared.
- **Root cause**: Parameter passing broke at several points in the call chain:
  1. `server/api.py` did not pass `training_mode` when calling `run_prediction`.
  2. `server/core/predictor.py` did not pass `store_id` and `training_mode` when calling `load_model_and_predict`.
  3. The data-loading logic inside `server/predictors/model_predictor.py` mistakenly used the model identifier (`store_{id}`) as a product ID to filter data during store prediction, so no history data could ever be loaded.
- **Final solution (three fixes)**:
  1. **`model_predictor.py`**: `load_model_and_predict` now loads data according to `training_mode`. In `'store'` mode it aggregates all of the store's sales as history, exactly matching how the training data was prepared.
  2. **`predictor.py`**: The `predict` method now forwards `store_id` and `training_mode` to `load_model_and_predict`.
  3. **`api.py`**: The `predict` route and the `run_prediction` helper now pass `training_mode` through the entire call chain.
- **Conclusion**: Parameter passing from the API down to the lowest data loader is now complete and correct. Both product and store predictions load the right history data for plotting; the blank-chart problem is gone.

### 16:16 - Project status update
- **Status**: **All known issues fixed.**
- **Confirmation**: The user confirmed that "the product and store prediction flows now work".
- **Follow-up**: Archive this fix in this document.

---

### 2025-07-16 18:38 - Making prediction work for every model type

**Symptom**:
After the `Transformer` prediction issue was resolved, a deeper systemic problem surfaced: in every prediction mode (by product, by store, global), only the `Transformer` algorithm could predict and draw a chart. The other four models (`mLSTM`, `KAN`, optimized `KAN`, `TCN`) trained successfully but failed at prediction time with "no data available for the chart".

**Root-cause analysis**:
The core problem was **incomplete and inconsistent persistence of model configuration**.

1. **Why Transformer "survived"**: Its implementation happens not to depend on the specific hyperparameters that were omitted at save time, so it kept working.
2. **The shared defect in the other models**: The constructors of `mLSTM`, `TCN`, and `KAN` all depend on structural parameters that were defined at training time but **omitted** from the checkpoint file (`.pth`):
   * **mLSTM**: missing `mlstm_layers`, `embed_dim`, `dense_dim`, among others.
   * **TCN**: missing `num_channels`, `kernel_size`, among others.
   * **KAN**: missing the `hidden_sizes` list.
3. **The failure cascade**:
   * When `server/predictors/model_predictor.py` loaded these checkpoints, `checkpoint['config']` lacked parameters required to instantiate the model.
   * Instantiation failed with a `KeyError` or `TypeError`.
   * That exception made `load_model_and_predict` return `None` early, so the response sent to the frontend had no `history_data`, and the chart could not render.

**Systematic, extensible solution**:
To fix this for good, and to make adding future algorithms painless, every non-Transformer trainer was fixed in the same standardized way (see the sketch after this list):

1. **`mlstm_trainer.py`**: added `mlstm_layers`, `embed_dim`, `dense_dim`, and all other missing parameters to the `config` dict.
2. **`tcn_trainer.py`**: added `num_channels`, `kernel_size`, and all other missing parameters to the `config` dict.
3. **`kan_trainer.py`**: added the `hidden_sizes` list to the `config` dict.
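The save/load contract, sketched under the assumption that the keys stored in `config` exactly match the model constructor's parameters (which is what this fix guarantees); the helper names are illustrative:

```python
import torch

def save_full_checkpoint(model, path, **structural_params):
    """Persist everything needed to rebuild the model, not just its weights."""
    torch.save({
        "state_dict": model.state_dict(),
        "config": structural_params,  # e.g. mlstm_layers, embed_dim, dense_dim, ...
    }, path)

def load_full_checkpoint(model_cls, path, device="cpu"):
    checkpoint = torch.load(path, map_location=device)
    model = model_cls(**checkpoint["config"])  # raises if the config is incomplete
    model.load_state_dict(checkpoint["state_dict"])
    return model.eval()
```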
**Result**:
Every trainer now writes a complete configuration into its checkpoint file, sufficient to re-instantiate the model. This eliminates the prediction failures for all model algorithms and makes the system uniform and robust across algorithms.
Binary file not shown.
server/api.py — 151 changed lines
@@ -85,6 +85,7 @@ import numpy as np
 import io
 from werkzeug.utils import secure_filename
 import random
+import uuid

 # 导入训练进度管理器 - 延迟初始化以避免循环导入
 try:

@@ -1507,47 +1508,56 @@ def predict():
     """
     try:
         data = request.json
-        product_id = data.get('product_id')
        model_type = data.get('model_type')
-        store_id = data.get('store_id')  # 新增店铺ID参数
-        version = data.get('version')  # 新增版本参数
+        version = data.get('version')
        future_days = int(data.get('future_days', 7))
        start_date = data.get('start_date', '')
        include_visualization = data.get('include_visualization', False)

-        scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"
-        print(f"API接收到预测请求: product_id={product_id}, model_type={model_type}, version={version}{scope_msg}, future_days={future_days}, start_date={start_date}")
+        # 确定训练模式和标识符
+        training_mode = data.get('training_mode', 'product')
+        product_id = data.get('product_id')
+        store_id = data.get('store_id')

-        if not product_id or not model_type:
-            return jsonify({"status": "error", "error": "product_id 和 model_type 是必需的"}), 400
+        if training_mode == 'global':
+            # 全局模式:使用硬编码的标识符,并为预测函数设置占位符
+            model_identifier = "global_all_products_sum"
+            product_id = 'all_products'
+            product_name = "全局聚合数据"
+        elif training_mode == 'store':
+            # 店铺模式:验证store_id并构建标识符
+            if not store_id:
+                return jsonify({"status": "error", "error": "店铺模式需要 store_id"}), 400
+            model_identifier = f"store_{store_id}"
+            product_name = f"店铺 {store_id} 整体"
+        else:  # 默认为 'product' 模式
+            # 药品模式:验证product_id并构建标识符
+            if not product_id:
+                return jsonify({"status": "error", "error": "药品模式需要 product_id"}), 400
+            model_identifier = product_id
+            product_name = get_product_name(product_id) or product_id

-        # 获取产品名称
-        product_name = get_product_name(product_id)
-        if not product_name:
-            product_name = product_id
+        print(f"API接收到预测请求: mode={training_mode}, model_identifier='{model_identifier}', model_type='{model_type}', version='{version}'")

-        # 根据版本获取模型ID
-        if version:
-            # 如果指定了版本,构造版本化的模型ID
-            model_id = f"{product_id}_{model_type}_{version}"
-            # 检查指定版本的模型是否存在
-            model_file_path = get_model_file_path(product_id, model_type, version)
-            if not os.path.exists(model_file_path):
-                return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型版本 {version}"}), 404
-        else:
-            # 如果没有指定版本,使用最新版本
-            latest_version = get_latest_model_version(product_id, model_type)
-            if latest_version:
-                model_id = f"{product_id}_{model_type}_{latest_version}"
-                version = latest_version
-            else:
-                # 兼容旧的无版本模型
-                model_id = get_latest_model_id(model_type, product_id)
-                if not model_id:
-                    return jsonify({"status": "error", "error": f"未找到产品 {product_id} 的 {model_type} 类型模型"}), 404
+        if not model_type:
+            return jsonify({"status": "error", "error": "model_type 是必需的"}), 400
+
+        # 获取模型版本
+        if not version:
+            version = get_latest_model_version(model_identifier, model_type)
+
+        if not version:
+            return jsonify({"status": "error", "error": f"未找到标识符为 {model_identifier} 的 {model_type} 类型模型"}), 404
+
+        # 检查模型文件是否存在
+        model_file_path = get_model_file_path(model_identifier, model_type, version)
+        if not os.path.exists(model_file_path):
+            return jsonify({"status": "error", "error": f"未找到模型文件: {model_file_path}"}), 404
+
+        model_id = f"{model_identifier}_{model_type}_{version}"

        # 执行预测
-        prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id)
+        prediction_result = run_prediction(model_type, product_id, model_id, future_days, start_date, version, store_id, training_mode)

        if prediction_result is None:
            return jsonify({"status": "error", "error": "预测失败,预测器返回None"}), 500

@@ -1610,17 +1620,20 @@ def predict():
         # 递归处理整个预测结果对象,确保所有NumPy类型都被转换
         processed_result = convert_numpy_types(prediction_result)

        # 构建前端期望的响应格式
        response_data = {
            'status': 'success',
-            'data': processed_result
+            'data': processed_result,
+            'history_data': [],
+            'prediction_data': []
        }

-        # 将history_data和prediction_data移到顶级
-        if 'history_data' in processed_result:
+        # 将history_data和prediction_data移到顶级,并确保它们存在
+        if 'history_data' in processed_result and processed_result['history_data']:
            response_data['history_data'] = processed_result['history_data']

-        if 'prediction_data' in processed_result:
+        if 'prediction_data' in processed_result and processed_result['prediction_data']:
            response_data['prediction_data'] = processed_result['prediction_data']

        # 调试日志:打印响应数据结构

@@ -2704,7 +2717,7 @@ def get_product_name(product_id):
     return None

 # 执行预测的辅助函数
-def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None):
+def run_prediction(model_type, product_id, model_id, future_days, start_date, version=None, store_id=None, training_mode='product'):
    """执行模型预测"""
    try:
        scope_msg = f", store_id={store_id}" if store_id else ", 全局模型"

@@ -2725,7 +2738,8 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
             store_id=store_id,
             future_days=future_days,
             start_date=start_date,
-            version=version
+            version=version,
+            training_mode=training_mode
        )

        if prediction_result is None:

@@ -2742,44 +2756,31 @@ def run_prediction(model_type, product_id, model_id, future_days, start_date, ve
         # 将DataFrame转换为prediction_data格式
         prediction_data = []
         for _, row in predictions_df.iterrows():
+            # 纠正:预测器返回的DataFrame中使用'sales'作为预测值列名。
+            # 我们从'sales'列读取,然后放入前端期望的'predicted_sales'键中。
+            sales_value = float(row['sales']) if pd.notna(row['sales']) else 0.0
            item = {
                'date': row['date'].strftime('%Y-%m-%d') if hasattr(row['date'], 'strftime') else str(row['date']),
-                'predicted_sales': float(row['sales']) if pd.notna(row['sales']) else 0.0,
-                'sales': float(row['sales']) if pd.notna(row['sales']) else 0.0  # 兼容字段
+                'predicted_sales': sales_value,
+                'sales': sales_value  # 兼容字段
            }
            prediction_data.append(item)

        prediction_result['prediction_data'] = prediction_data

-        # 获取历史数据用于对比
-        try:
-            # 读取原始数据
-            from utils.multi_store_data_utils import load_multi_store_data
-            df = load_multi_store_data()
-            product_df = df[df['product_id'] == product_id].copy()
-
-            if not product_df.empty:
-                # 获取最近30天的历史数据
-                product_df['date'] = pd.to_datetime(product_df['date'])
-                product_df = product_df.sort_values('date')
-
-                # 取最后30天的数据
-                recent_history = product_df.tail(30)
-
-                history_data = []
-                for _, row in recent_history.iterrows():
-                    item = {
-                        'date': row['date'].strftime('%Y-%m-%d'),
-                        'sales': float(row['sales']) if pd.notna(row['sales']) else 0.0
-                    }
-                    history_data.append(item)
-
-                prediction_result['history_data'] = history_data
-        except Exception as e:
-            print(f"获取历史数据失败: {str(e)}")
-            # 确保即使没有数据,也返回一个空列表
-            prediction_result['history_data'] = []
+        # 统一数据格式:确保历史数据和预测数据在发送给前端前具有相同的结构和日期格式。
+        if 'history_data' in prediction_result and isinstance(prediction_result['history_data'], pd.DataFrame):
+            history_df = prediction_result['history_data']
+            history_data_list = []
+            for _, row in history_df.iterrows():
+                item = {
+                    'date': row['date'].strftime('%Y-%m-%d'),
+                    'sales': float(row['sales']) if pd.notna(row['sales']) else 0.0
+                }
+                history_data_list.append(item)
+            prediction_result['history_data'] = history_data_list
+        else:
+            prediction_result['history_data'] = []

        return prediction_result

@@ -3279,7 +3280,6 @@ def analyze_prediction(prediction_result):
         sample_dates = [item.get('date') for item in prediction_data if item.get('date')]
         sample_dates = [d.strftime('%Y-%m-%d') if not isinstance(d, str) else d for d in sample_dates if d]
         if sample_dates:
-            import random
            analysis['history_chart_data'] = {
                'dates': sample_dates,
                'changes': [round(random.uniform(-5, 5), 2) for _ in range(len(sample_dates))]

@@ -3355,15 +3355,28 @@ def save_prediction_result(prediction_result, product_id, product_name, model_ty
         conn = get_db_connection()
         cursor = conn.cursor()

+        # 提取并序列化需要存储的数据
+        predictions_data_json = json.dumps(prediction_result.get('prediction_data', []), cls=CustomJSONEncoder)
+
+        # 从分析结果中获取指标,如果分析结果不存在,则使用空字典
+        analysis_data = prediction_result.get('analysis', {})
+        metrics_data = analysis_data.get('metrics', {}) if isinstance(analysis_data, dict) else {}
+        metrics_json = json.dumps(metrics_data, cls=CustomJSONEncoder)
+
+        chart_data_json = json.dumps(prediction_result.get('chart_data', {}), cls=CustomJSONEncoder)
+        analysis_json = json.dumps(analysis_data, cls=CustomJSONEncoder)
+
        cursor.execute('''
            INSERT INTO prediction_history (
-                id, product_id, product_name, model_type, model_id,
-                start_date, future_days, created_at, file_path
-            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)
+                prediction_id, product_id, product_name, model_type, model_id,
+                start_date, future_days, created_at, file_path,
+                predictions_data, metrics, chart_data, analysis
+            ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
        ''', (
            prediction_id, product_id, product_name, model_type, model_id,
            start_date if start_date else datetime.now().strftime('%Y-%m-%d'),
-            future_days, datetime.now().isoformat(), file_path
+            future_days, datetime.now().isoformat(), file_path,
+            predictions_data_json, metrics_json, chart_data_json, analysis_json
        ))

        conn.commit()

@@ -3810,7 +3823,9 @@ def get_store_model_versions_api(store_id, model_type):
 def get_global_model_versions_api(model_type):
     """获取全局模型版本列表API"""
     try:
-        model_identifier = "global"
+        # 全局模型的标识符是在训练时确定的,例如 'global_all_products_sum'
+        # 这里我们假设前端请求的是默认的全局模型
+        model_identifier = "global_all_products_sum"
        versions = get_model_versions(model_identifier, model_type)
        latest_version = get_latest_model_version(model_identifier, model_type)
server/core/config.py

@@ -114,33 +114,27 @@ def get_next_model_version(product_id: str, model_type: str) -> str:
     else:
         return DEFAULT_VERSION

-def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str:
+def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
    """
-    生成模型文件路径
+    根据产品ID、模型类型和版本号,生成模型文件的准确路径。

    Args:
-        product_id: 产品ID
+        product_id: 产品ID (纯数字)
        model_type: 模型类型
-        version: 版本号,如果为None则获取下一个版本
+        version: 版本字符串 (例如 'best', 'final_epoch_50', 'v1_legacy')

    Returns:
        模型文件的完整路径
    """
-    if version is None:
-        version = get_next_model_version(product_id, model_type)
-
-    # 特殊处理v1版本:检查是否存在旧格式文件
-    if version == "v1":
-        # 检查旧格式文件是否存在
-        old_format_filename = f"{model_type}_model_product_{product_id}.pth"
-        old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename)
-
-        if os.path.exists(old_format_path):
-            print(f"找到旧格式模型文件: {old_format_path},将其作为v1版本")
-            return old_format_path
-
-    # 使用新格式文件名
-    filename = f"{model_type}_model_product_{product_id}_{version}.pth"
+    # 处理历史遗留的 "v1" 格式
+    if version == "v1_legacy":
+        filename = f"{model_type}_model_product_{product_id}.pth"
+        return os.path.join(DEFAULT_MODEL_DIR, filename)
+
+    # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名
+    # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
+    filename = f"{model_type}_{product_id}_epoch_{version}.pth"
+    # 修正:直接在根模型目录查找,不再使用checkpoints子目录
    return os.path.join(DEFAULT_MODEL_DIR, filename)

 def get_model_versions(product_id: str, model_type: str) -> list:

@@ -148,40 +142,45 @@ def get_model_versions(product_id: str, model_type: str) -> list:
     获取指定产品和模型类型的所有版本

    Args:
-        product_id: 产品ID
+        product_id: 产品ID (现在应该是纯数字ID)
        model_type: 模型类型

    Returns:
        版本列表,按版本号排序
    """
-    # 新格式:带版本号的文件
-    pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
-    existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
+    # 直接使用传入的product_id构建搜索模式
+    # 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
+    # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建搜索模式
+    pattern = f"{model_type}_{product_id}_epoch_*.pth"
+    # 修正:直接在根模型目录查找,不再使用checkpoints子目录
+    search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
+    existing_files = glob.glob(search_path)

-    # 旧格式:不带版本号的文件(兼容性支持)
+    # 旧格式(兼容性支持)
    pattern_old = f"{model_type}_model_product_{product_id}.pth"
    old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
    has_old_format = os.path.exists(old_file_path)

-    versions = []
+    versions = set()  # 使用集合避免重复

-    # 处理新格式文件
-    for file_path in existing_files_new:
+    # 从找到的文件中提取版本信息
+    for file_path in existing_files:
        filename = os.path.basename(file_path)
-        version_match = re.search(rf"_v(\d+)\.pth$", filename)
+        # 匹配 _epoch_ 后面的内容作为版本
+        version_match = re.search(r"_epoch_(.+)\.pth$", filename)
        if version_match:
-            version_num = int(version_match.group(1))
-            versions.append(f"v{version_num}")
+            versions.add(version_match.group(1))

    # 如果存在旧格式文件,将其视为v1
    if has_old_format:
-        if "v1" not in versions:  # 避免重复添加
-            versions.append("v1")
-            print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1")
+        versions.add("v1_legacy")  # 添加一个特殊标识
+        print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")

-    # 按版本号排序
-    versions.sort(key=lambda v: int(v[1:]))
-    return versions
+    # 转换为列表并排序
+    sorted_versions = sorted(list(versions))
+    return sorted_versions

 def get_latest_model_version(product_id: str, model_type: str) -> str:
     """
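Editor's note: a self-contained sketch of the new version-listing round-trip, mirroring the logic above (directory and identifier values are illustrative):

```python
import glob
import os
import re

def list_versions(model_dir: str, model_type: str, identifier: str) -> list:
    """Collect version labels from files named {type}_{identifier}_epoch_{version}.pth."""
    pattern = os.path.join(model_dir, f"{model_type}_{identifier}_epoch_*.pth")
    versions = set()
    for path in glob.glob(pattern):
        m = re.search(r"_epoch_(.+)\.pth$", os.path.basename(path))
        if m:
            versions.add(m.group(1))
    # Like the sorted() call in the diff, this sorts lexicographically,
    # so numeric epochs are ordered as strings: ['10', '2', 'best', ...].
    return sorted(versions)
```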
server/core/predictor.py

@@ -132,8 +132,8 @@ class PharmacyPredictor:
                     file_path=self.data_path
                 )
                 log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
-                # 将product_id设置为店铺ID,以便模型保存时使用有意义的标识
-                product_id = store_id
+                # 将product_id设置为'store_{store_id}',与API查找逻辑保持一致
+                product_id = f"store_{store_id}"
            except Exception as e:
                log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
                return None

@@ -179,7 +179,7 @@ class PharmacyPredictor:

        # 根据训练模式构建模型标识符
        if training_mode == 'store':
-            model_identifier = f"{store_id}_{product_id}"
+            model_identifier = product_id
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:

@@ -191,11 +191,14 @@ class PharmacyPredictor:
             if model_type == 'transformer':
                 model_result, metrics, actual_version = train_product_model_with_transformer(
                     product_id=product_id,
+                    model_identifier=model_identifier,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
+                    sequence_length=sequence_length,
+                    forecast_horizon=forecast_horizon,
                    model_dir=self.model_dir,
                    version=version,
                    socketio=socketio,

@@ -206,11 +209,14 @@ class PharmacyPredictor:
             elif model_type == 'mlstm':
                 _, metrics, _, _ = train_product_model_with_mlstm(
                     product_id=product_id,
+                    model_identifier=model_identifier,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
+                    sequence_length=sequence_length,
+                    forecast_horizon=forecast_horizon,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id,

@@ -219,33 +225,42 @@ class PharmacyPredictor:
             elif model_type == 'kan':
                 _, metrics = train_product_model_with_kan(
                     product_id=product_id,
+                    model_identifier=model_identifier,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
+                    sequence_length=sequence_length,
+                    forecast_horizon=forecast_horizon,
                    use_optimized=use_optimized,
                    model_dir=self.model_dir
                )
            elif model_type == 'optimized_kan':
                _, metrics = train_product_model_with_kan(
                    product_id=product_id,
+                    model_identifier=model_identifier,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
+                    sequence_length=sequence_length,
+                    forecast_horizon=forecast_horizon,
                    use_optimized=True,
                    model_dir=self.model_dir
                )
            elif model_type == 'tcn':
                _, metrics, _, _ = train_product_model_with_tcn(
                    product_id=product_id,
+                    model_identifier=model_identifier,
                    product_df=product_data,
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    epochs=epochs,
+                    sequence_length=sequence_length,
+                    forecast_horizon=forecast_horizon,
                    model_dir=self.model_dir,
                    socketio=socketio,
                    task_id=task_id

@@ -298,7 +313,8 @@ class PharmacyPredictor:
         """
         # 根据训练模式构建模型标识符
         if training_mode == 'store' and store_id:
-            model_identifier = f"{store_id}_{product_id}"
+            # 修正:店铺模型的标识符应该只基于店铺ID
+            model_identifier = f"store_{store_id}"
        elif training_mode == 'global':
            model_identifier = f"global_{product_id}_{aggregation_method}"
        else:

@@ -307,10 +323,12 @@ class PharmacyPredictor:
         return load_model_and_predict(
             model_identifier,
             model_type,
+            store_id=store_id,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result,
-            version=version
+            version=version,
+            training_mode=training_mode
        )

    def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
server/predictors/model_predictor.py

@@ -21,9 +21,9 @@ from models.optimized_kan_forecaster import OptimizedKANForecaster
 from analysis.trend_analysis import analyze_prediction_result
 from utils.visualization import plot_prediction_results
 from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
-from core.config import DEVICE, get_model_file_path
+from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH

-def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None):
+def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'):
    """
    加载已训练的模型并进行预测

@@ -101,41 +101,43 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,

     # 加载销售数据(支持多店铺)
     try:
-        if store_id:
-            # 加载特定店铺的数据
-            product_df = get_store_product_sales_data(
-                store_id,
-                product_id,
-                'pharmacy_sales_multi_store.csv'
-            )
-            store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
-            prediction_scope = f"店铺 '{store_name}' ({store_id})"
-        else:
-            # 聚合所有店铺的数据进行预测
-            from utils.multi_store_data_utils import aggregate_multi_store_data
+        # 根据训练模式加载相应的数据
+        if training_mode == 'store' and store_id:
+            # 店铺模型:聚合该店铺的所有产品数据
            product_df = aggregate_multi_store_data(
                product_id,
+                store_id=store_id,
                aggregation_method='sum',
-                file_path='pharmacy_sales_multi_store.csv'
+                file_path=DEFAULT_DATA_PATH
            )
+            store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}"
+            prediction_scope = f"店铺 '{store_name}' ({store_id})"
+            product_name = store_name
+        elif training_mode == 'global':
+            # 全局模型:聚合所有数据
+            product_df = aggregate_multi_store_data(
+                aggregation_method='sum',
+                file_path=DEFAULT_DATA_PATH
+            )
+            prediction_scope = "全局聚合数据"
+            product_name = "全局销售数据"
+        else:
+            # 产品模型(默认):聚合该产品在所有店铺的数据
+            product_df = aggregate_multi_store_data(
+                product_id=product_id,
+                aggregation_method='sum',
+                file_path=DEFAULT_DATA_PATH
+            )
            prediction_scope = "全部店铺(聚合数据)"
+            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
    except Exception as e:
-        print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
-        # 后向兼容:尝试加载原始数据格式
-        try:
-            df = pd.read_excel('pharmacy_sales.xlsx')
-            product_df = df[df['product_id'] == product_id].sort_values('date')
-            if store_id:
-                print(f"警告:原始数据不支持店铺过滤,将使用所有数据预测")
-            prediction_scope = "默认数据"
-        except Exception as e2:
-            print(f"加载产品数据失败: {str(e2)}")
-            return None
+        print(f"加载数据失败: {e}")
+        return None

    if product_df.empty:
-        print(f"产品 {product_id} 没有销售数据")
+        print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
        return None

-    product_name = product_df['product_name'].iloc[0]
    print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
    print(f"预测范围: {prediction_scope}")

@@ -262,7 +264,7 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,

     # 准备输入数据
     try:
-        features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
+        features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        sequence_length = config['sequence_length']

        # 获取最近的sequence_length天数据作为输入

@@ -367,11 +369,24 @@ def load_model_and_predict(product_id, model_type, store_id=None, future_days=7,
             print(f"分析预测结果失败: {str(e)}")
             # 分析失败不影响主要功能,继续执行

+        # 准备用于图表展示的历史数据
+        history_df = product_df
+        if start_date:
+            try:
+                # 筛选出所有早于预测起始日期的数据
+                history_df = product_df[product_df['date'] < pd.to_datetime(start_date)]
+            except Exception as e:
+                print(f"筛选历史数据时日期格式错误: {e}")
+
+        # 从正确的历史记录中取最后30天
+        recent_history = history_df.tail(30)
+
        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
+            'history_data': recent_history,  # 将历史数据添加到返回结果中
            'analysis': analysis
        }
    except Exception as e:
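Editor's note: the mode-dependent dispatch above, reduced to its skeleton. The calls mirror those in the diff; the wrapper name and the omission of `file_path` (relying on the loader's defaults) are assumptions:

```python
def load_history(training_mode, product_id=None, store_id=None):
    # assumed import path, matching the diff's utils.multi_store_data_utils module
    from utils.multi_store_data_utils import aggregate_multi_store_data

    if training_mode == "store" and store_id:
        # store model: aggregate every product sold in this store
        return aggregate_multi_store_data(product_id, store_id=store_id,
                                          aggregation_method="sum")
    if training_mode == "global":
        # global model: aggregate everything
        return aggregate_multi_store_data(aggregation_method="sum")
    # product model (default): aggregate this product across all stores
    return aggregate_multi_store_data(product_id=product_id, aggregation_method="sum")
```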
kan_trainer.py

@@ -21,7 +21,7 @@ from utils.visualization import plot_loss_curve
 from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON

-def train_product_model_with_kan(product_id, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
+def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
    """
    使用KAN模型训练产品销售预测模型

@@ -79,11 +79,11 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
         raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"

@@ -123,8 +123,8 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # 转换为PyTorch的Tensor
    trainX_tensor = torch.Tensor(trainX)

@@ -142,7 +142,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

     # 初始化KAN模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
    hidden_size = 64

    if use_optimized:

@@ -283,8 +283,8 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': model_type_name,
            'use_optimized': use_optimized
        },

@@ -299,7 +299,7 @@ def train_product_model_with_kan(product_id, product_df=None, store_id=None, tra

     model_path = model_manager.save_model(
         model_data=model_data,
-        product_id=product_id,
+        product_id=model_identifier,
        model_type=model_type_name,
        version='v1',  # KAN训练器默认使用v1
        store_id=store_id,
mlstm_trainer.py

@@ -25,7 +25,7 @@ from core.config import (
 )
 from utils.training_progress import progress_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
+def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """

@@ -42,16 +42,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
         aggregation_method: 聚合方法
     """
-    # 创建检查点目录
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
+    # 直接在模型根目录保存,不再创建子目录
+    checkpoint_dir = model_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
+    # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
+    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

@@ -106,11 +102,14 @@ def load_checkpoint(product_id: str, model_type: str, epoch_or_label,

 def train_product_model_with_mlstm(
     product_id,
+    model_identifier,
    product_df,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
+    sequence_length=LOOK_BACK,
+    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,

@@ -215,11 +214,11 @@ def train_product_model_with_mlstm(
         training_scope = "所有店铺"

    # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"

@@ -269,8 +268,8 @@ def train_product_model_with_mlstm(
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # 转换为PyTorch的Tensor
    trainX_tensor = torch.Tensor(trainX)

@@ -295,7 +294,7 @@ def train_product_model_with_mlstm(

     # 初始化mLSTM结合Transformer模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
    hidden_size = 128
    num_heads = 4
    dropout_rate = 0.1

@@ -432,12 +431,13 @@ def train_product_model_with_mlstm(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_heads': num_heads,
-            'dropout': dropout_rate,
+            'dropout_rate': dropout_rate,
            'num_blocks': num_blocks,
            'embed_dim': embed_dim,
            'dense_dim': dense_dim,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'mlstm_layers': 2,  # 确保这个参数被保存
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'mlstm'
        },
        'training_info': {

@@ -452,13 +452,13 @@ def train_product_model_with_mlstm(
     }

    # 保存检查点
-    save_checkpoint(checkpoint_data, epoch + 1, product_id, 'mlstm',
+    save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'mlstm',
                    model_dir, store_id, training_mode, aggregation_method)

    # 如果是最佳模型,额外保存一份
    if test_loss < best_loss:
        best_loss = test_loss
-        save_checkpoint(checkpoint_data, 'best', product_id, 'mlstm',
+        save_checkpoint(checkpoint_data, 'best', model_identifier, 'mlstm',
                        model_dir, store_id, training_mode, aggregation_method)
        emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
        epochs_no_improve = 0

@@ -553,12 +553,13 @@ def train_product_model_with_mlstm(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_heads': num_heads,
-            'dropout': dropout_rate,
+            'dropout_rate': dropout_rate,
            'num_blocks': num_blocks,
            'embed_dim': embed_dim,
            'dense_dim': dense_dim,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'mlstm_layers': 2,  # 确保这个参数被保存
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'mlstm'
        },
        'metrics': metrics,

@@ -577,7 +578,7 @@ def train_product_model_with_mlstm(

     # 保存最终模型(使用epoch标识)
     final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'mlstm',
+        final_model_data, f"final_epoch_{epochs}", model_identifier, 'mlstm',
        model_dir, store_id, training_mode, aggregation_method
    )
tcn_trainer.py

@@ -21,7 +21,7 @@ from analysis.metrics import evaluate_model
 from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
 from utils.training_progress import progress_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
+def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """

@@ -38,16 +38,13 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
         aggregation_method: 聚合方法
     """
-    # 创建检查点目录
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
+    # 直接在模型根目录保存,不再创建子目录
+    checkpoint_dir = model_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
+    # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
+    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

@@ -59,11 +56,14 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,

 def train_product_model_with_tcn(
     product_id,
+    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
+    sequence_length=LOOK_BACK,
+    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,

@@ -159,11 +159,11 @@ def train_product_model_with_tcn(
         raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"

@@ -212,8 +212,8 @@ def train_product_model_with_tcn(
     progress_manager.set_stage("data_preprocessing", 50)

    # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # 转换为PyTorch的Tensor
    trainX_tensor = torch.Tensor(trainX)

@@ -240,7 +240,7 @@ def train_product_model_with_tcn(

     # 初始化TCN模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
    hidden_size = 64
    num_layers = 3
    kernel_size = 3

@@ -382,10 +382,11 @@ def train_product_model_with_tcn(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_layers': num_layers,
+            'num_channels': [hidden_size] * num_layers,
            'dropout': dropout_rate,
+            'kernel_size': kernel_size,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'tcn'
        },
        'training_info': {

@@ -399,13 +400,13 @@ def train_product_model_with_tcn(
     }

    # 保存检查点
-    save_checkpoint(checkpoint_data, epoch + 1, product_id, 'tcn',
+    save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'tcn',
                    model_dir, store_id, training_mode, aggregation_method)

    # 如果是最佳模型,额外保存一份
    if test_loss < best_loss:
        best_loss = test_loss
-        save_checkpoint(checkpoint_data, 'best', product_id, 'tcn',
+        save_checkpoint(checkpoint_data, 'best', model_identifier, 'tcn',
                        model_dir, store_id, training_mode, aggregation_method)
        emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")

@@ -472,10 +473,11 @@ def train_product_model_with_tcn(
             'output_dim': output_dim,
             'hidden_size': hidden_size,
             'num_layers': num_layers,
+            'num_channels': [hidden_size] * num_layers,
            'dropout': dropout_rate,
+            'kernel_size': kernel_size,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'tcn'
        },
        'metrics': metrics,

@@ -495,7 +497,7 @@ def train_product_model_with_tcn(

     # 保存最终模型(使用epoch标识)
     final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'tcn',
+        final_model_data, f"final_epoch_{epochs}", model_identifier, 'tcn',
        model_dir, store_id, training_mode, aggregation_method
    )
transformer_trainer.py

@@ -27,7 +27,7 @@ from core.config import (
 from utils.training_progress import progress_manager
 from utils.model_manager import model_manager

-def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
+def save_checkpoint(checkpoint_data: dict, epoch_or_label, model_identifier: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None):
    """

@@ -43,17 +43,12 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
         training_mode: 训练模式
         aggregation_method: 聚合方法
     """
-    # 创建检查点目录
-    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
+    # 直接在模型根目录保存,不再创建子目录
+    checkpoint_dir = model_dir
    os.makedirs(checkpoint_dir, exist_ok=True)

    # 生成检查点文件名
-    if training_mode == 'store' and store_id:
-        filename = f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
-    elif training_mode == 'global' and aggregation_method:
-        filename = f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
-    else:
-        filename = f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"
+    # 修正:直接使用product_id作为唯一标识符,因为它已经包含了store_前缀或药品ID
+    filename = f"{model_type}_{model_identifier}_epoch_{epoch_or_label}.pth"

    checkpoint_path = os.path.join(checkpoint_dir, filename)

@@ -65,11 +60,14 @@ def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,

 def train_product_model_with_transformer(
     product_id,
+    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
+    sequence_length=LOOK_BACK,
+    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,

@@ -177,11 +175,11 @@ def train_product_model_with_transformer(
         raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # 数据量检查
-    min_required_samples = LOOK_BACK + FORECAST_HORIZON
+    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
-            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
+            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)} 天\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"

@@ -225,8 +223,8 @@ def train_product_model_with_transformer(
     y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # 创建时间序列数据
-    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
-    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
+    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
+    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    progress_manager.set_stage("data_preprocessing", 70)

@@ -256,7 +254,7 @@ def train_product_model_with_transformer(

     # 初始化Transformer模型
     input_dim = X_train.shape[1]
-    output_dim = FORECAST_HORIZON
+    output_dim = forecast_horizon
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1

@@ -270,7 +268,7 @@ def train_product_model_with_transformer(
         dim_feedforward=hidden_size * 2,
         dropout=dropout_rate,
         output_sequence_length=output_dim,
-        seq_length=LOOK_BACK,
+        seq_length=sequence_length,
        batch_size=batch_size
    )

@@ -387,8 +385,8 @@ def train_product_model_with_transformer(
             'num_heads': num_heads,
             'dropout': dropout_rate,
             'num_layers': num_layers,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'transformer'
        },
        'training_info': {

@@ -402,13 +400,13 @@ def train_product_model_with_transformer(
     }

    # 保存检查点
-    save_checkpoint(checkpoint_data, epoch + 1, product_id, 'transformer',
+    save_checkpoint(checkpoint_data, epoch + 1, model_identifier, 'transformer',
                    model_dir, store_id, training_mode, aggregation_method)

    # 如果是最佳模型,额外保存一份
    if test_loss < best_loss:
        best_loss = test_loss
-        save_checkpoint(checkpoint_data, 'best', product_id, 'transformer',
+        save_checkpoint(checkpoint_data, 'best', model_identifier, 'transformer',
                        model_dir, store_id, training_mode, aggregation_method)
        emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
        epochs_no_improve = 0

@@ -483,8 +481,8 @@ def train_product_model_with_transformer(
             'num_heads': num_heads,
             'dropout': dropout_rate,
             'num_layers': num_layers,
-            'sequence_length': LOOK_BACK,
-            'forecast_horizon': FORECAST_HORIZON,
+            'sequence_length': sequence_length,
+            'forecast_horizon': forecast_horizon,
            'model_type': 'transformer'
        },
        'metrics': metrics,

@@ -504,7 +502,7 @@ def train_product_model_with_transformer(

     # 保存最终模型(使用epoch标识)
     final_model_path = save_checkpoint(
-        final_model_data, f"final_epoch_{epochs}", product_id, 'transformer',
+        final_model_data, f"final_epoch_{epochs}", model_identifier, 'transformer',
        model_dir, store_id, training_mode, aggregation_method
    )
@@ -756,3 +756,33 @@
 ```

 通过以上步骤,您就可以在不改动项目其他任何部分的情况下,轻松地将数据源从本地文件切换到服务器数据库。
+
+---
+**Date**: 2025-07-15
+**Topic**: Fixed "Predict by Product" and enhanced its chart
+**Developer**: lyf
+
+### Problem
+The "Prediction Analysis" -> "Predict by Product" page was unusable: the frontend called the wrong API URL, and the chart-rendering logic did not match the data structure returned by the backend.
+
+### Solution
+The following fixes and enhancements were made to `UI/src/views/prediction/ProductPredictionView.vue`:
+
+1. **API endpoint fix**:
+   * **Where**: the `startPrediction` function.
+   * **What**: changed the request URL from the incorrect `/api/predict` to the correct `/api/prediction`.
+
+2. **Data-handling alignment**:
+   * **Where**: the `startPrediction` and `renderChart` functions.
+   * **What**: updated the data-receiving logic to handle the `history_data` and `prediction_data` fields returned by the backend.
+
+3. **Chart enhancement**:
+   * **Where**: the `renderChart` function.
+   * **What**: reworked the rendering logic to show historical sales (solid green line) and predicted sales (dashed blue line) together, giving users a more intuitive comparison.
+
+4. **Better error messages**:
+   * **Where**: the `catch` block of `startPrediction`.
+   * **What**: improved the error handling so the more specific error message from the backend response is extracted and displayed.
+### Final result
+The "Predict by Product" feature now works end to end against the backend and provides a richer, more robust visualization experience.