ShopTRAINING/系统调用逻辑与核心代码分析.md

466 lines
22 KiB
Markdown
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 系统调用逻辑与核心代码分析
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
## 1. 系统启动
系统由两部分组成Vue.js前端和Flask后端。
### 1.1. 启动后端API服务
在项目根目录下,通过以下命令启动后端服务:
```bash
python server/api.py
```
该命令会启动一个Flask应用监听在 `http://localhost:5000`并提供所有API和WebSocket服务。
### 1.2. 启动前端开发服务器
进入 `UI` 目录,执行以下命令:
```bash
cd UI
npm install
npm run dev
```
这将启动Vite开发服务器通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
## 2. 核心调用链路概览
以最核心的 **“按药品训练 -> 按药品预测”** 流程为例,其高层调用链路如下:
**训练流程:**
`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth`
**预测流程:**
`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染`
## 3. 详细流程:按药品训练
此流程的目标是为特定药品训练一个专用的预测模型。
### 3.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型如Transformer、设置训练轮次Epochs然后点击 **“启动药品训练”** 按钮。
2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。
3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。
*核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
```javascript
const startTraining = async () => {
// ... 表单验证 ...
trainingLoading.value = true;
try {
const endpoint = "/api/training";
const payload = {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
epochs: form.epochs,
training_mode: 'product' // 标识这是药品训练模式
};
const response = await axios.post(endpoint, payload);
// ... 处理响应启动WebSocket监听 ...
}
// ... 错误处理 ...
};
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。
### 3.2. 后端API接收与任务分发
1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。
2. **任务提交**: `start_training()` 函数解析请求中的JSON数据然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行以避免阻塞API主线程。这使得API可以立即返回一个任务ID而训练在后台异步进行。
*核心代码 ([`server/api.py`](server/api.py:971))*
```python
@app.route('/api/training', methods=['POST'])
def start_training():
data = request.get_json()
training_mode = data.get('training_mode', 'product')
model_type = data.get('model_type')
epochs = data.get('epochs', 50)
product_id = data.get('product_id')
store_id = data.get('store_id')
if not model_type or (training_mode == 'product' and not product_id):
return jsonify({'error': '缺少必要参数'}), 400
try:
# 使用训练进程管理器提交任务
task_id = training_manager.submit_task(
product_id=product_id or "unknown",
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
```
### 3.3. 核心训练逻辑
1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。
2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。
3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type` 是 `transformer`,它会调用 `train_product_model_with_transformer`。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))*
```python
class PharmacyPredictor:
def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
# ...
if training_mode == 'product':
product_data = self.data[self.data['product_id'] == product_id].copy()
# ...
# 根据训练模式构建模型标识符
model_identifier = product_id
try:
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
# ... 其他参数 ...
)
# ... 其他模型的elif分支 ...
return metrics
except Exception as e:
# ... 错误处理 ...
return None
```
### 3.4. 模型训练与保存
1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤:
* **数据预处理**: 调用 `prepare_data` 和 `prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
* **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。
* **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。
* **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。
* **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典checkpoint并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier`、`model_type` 和 `version` 统一生成。
*核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
```python
def train_product_model_with_transformer(...):
# ... 数据准备 ...
# 定义模型配置
config = {
'input_dim': input_dim,
'output_dim': forecast_horizon,
'hidden_size': hidden_size,
# ... 所有必要的超参数 ...
'model_type': 'transformer'
}
model = TimeSeriesTransformer(...)
# ... 训练循环 ...
# 保存模型
checkpoint = {
'model_state_dict': model.state_dict(),
'config': config,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'metrics': test_metrics
}
model_path = get_model_file_path(model_identifier, 'transformer', version)
torch.save(checkpoint, model_path)
return model, test_metrics, version
```
## 4. 详细流程:按药品预测
训练完成后,用户可以使用已保存的模型进行预测。
### 4.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。
2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。
3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
```javascript
const startPrediction = async () => {
try {
predicting.value = true
const payload = {
product_id: form.product_id,
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
// training_mode is implicitly 'product' here
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data
await nextTick()
renderChart()
}
// ... 错误处理 ...
}
// ...
}
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。
### 4.2. 后端API接收与预测执行
1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。
2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。
*核心代码 ([`server/api.py`](server/api.py:1469))*
```python
@app.route('/api/prediction', methods=['POST'])
def predict():
try:
data = request.json
# ... 解析参数 ...
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
# ...
# 根据模式确定模型标识符
if training_mode == 'product':
model_identifier = product_id
# ...
# 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, ...)
# ... 格式化响应 ...
return jsonify(response_data)
except Exception as e:
# ... 错误处理 ...
```
3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode` 和 `product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))*
```python
class PharmacyPredictor:
def predict(self, product_id, model_type, ..., training_mode='product', ...):
if training_mode == 'product':
model_identifier = product_id
# ...
return load_model_and_predict(
model_identifier,
model_type,
# ... 其他参数 ...
)
```
### 4.3. 模型加载与执行预测
[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤:
1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。
2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。
3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。**
4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。
5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。
6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。
7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。
8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。
9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame并连同历史数据一起返回。
*核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
# ... 找到模型文件路径 model_path ...
# 加载模型和配置
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例 (以Transformer为例)
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
# ... 使用config中的所有参数 ...
).to(DEVICE)
# 加载模型参数
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# ... 准备输入数据 ...
# 归一化输入数据
X_scaled = scaler_X.transform(X)
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
# 预测
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
# 反归一化
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
# ... 构建返回结果 ...
return {
'predictions': predictions_df,
'history_data': recent_history,
# ...
}
```
### 4.4. 响应格式化与前端图表渲染
1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构该结构在顶层同时包含 `history_data` 和 `prediction_data` 两个数组。
2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应并将其存入 `predictionResult` ref。
3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data` 和 `prediction_data`然后使用Chart.js库将这两部分数据绘制在同一个 `<canvas>` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
```javascript
const renderChart = () => {
if (!chartCanvas.value || !predictionResult.value) return
// ...
// 后端直接提供 history_data 和 prediction_data
const historyData = predictionResult.value.history_data || []
const predictionData = predictionResult.value.prediction_data || []
const historyLabels = historyData.map(p => p.date)
const historySales = historyData.map(p => p.sales)
const predictionLabels = predictionData.map(p => p.date)
const predictionSales = predictionData.map(p => p.predicted_sales)
// ... 组合标签和数据,对齐数据点 ...
chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels: allLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
// ... 样式 ...
},
{
label: '预测销量',
data: alignedPredictionSales,
// ... 样式 ...
}
]
},
// ... Chart.js 配置 ...
})
}
```
至此,一个完整的“训练->预测->展示”的调用链路就完成了。
## 5. 模型保存与版本管理核心逻辑 (重构后)
为了根治版本混乱和模型加载失败的问题,系统进行了一项重要的重构。现在,所有与模型保存、命名和版本管理相关的逻辑都已**统一集中**到 [`server/utils/model_manager.py`](server/utils/model_manager.py:1) 的 `ModelManager` 类中。
### 5.1. 统一管理者:`ModelManager`
- **单一职责**: `ModelManager` 是系统中唯一负责处理模型文件IO的组件。所有训练器 (`trainer`) 在需要保存模型时,都必须通过它来进行。
- **核心功能**:
1. **自动版本控制**: 自动生成和递增符合规范的版本号。
2. **统一命名**: 根据模型的元数据算法类型、训练模式、ID等生成标准化的文件名。
3. **安全保存**: 将模型数据和元数据一起打包保存到 `.pth` 文件中。
4. **可靠检索**: 提供统一的接口来列出和查找模型。
### 5.2. 统一版本规范
所有模型版本现在都遵循一个严格的、可预测的格式:
- **数字版本**: `v{数字}`,例如 `v1`, `v2`, `v3`...
- **生成**: 当一次训练**正常完成**时,`ModelManager` 会自动计算出当前模型的下一个可用版本号(例如,如果已存在 `v1` 和 `v2`,则新版本为 `v3`),并以此命名最终的模型文件。
- **用途**: 代表一次完整的、稳定的训练产出。
- **特殊版本**: `best`
- **生成**: 在训练过程中,如果某个 `epoch` 产生的模型在验证集上的性能超过了之前所有 `epoch`,训练器会调用 `ModelManager` 将这个模型保存为 `best` 版本,覆盖掉旧的 `best` 模型。
- **用途**: 始终指向该模型迄今为止性能最佳的一个版本,便于快速进行高质量的预测。
### 5.3. 统一命名约定 (v2版)
随着系统增加了“按店铺”和“全局”训练模式,`ModelManager` 的 `generate_model_filename` 方法也已升级,以支持更丰富的、无歧义的命名格式:
- **药品模型**: `{model_type}_product_{product_id}_{version}.pth`
- *示例*: `transformer_product_17002608_best.pth`
- **店铺模型**: `{model_type}_store_{store_id}_{version}.pth`
- *示例*: `mlstm_store_01010023_v2.pth`
- **全局模型**: `{model_type}_global_{aggregation_method}_{version}.pth`
- *示例*: `tcn_global_sum_v1.pth`
这个新的命名系统确保了不同训练模式产出的模型可以清晰地被识别和管理。
### 5.4. Checkpoint文件内容 (结构不变)
每个 `.pth` 文件依然是一个包含模型权重、完整配置和数据缩放器的PyTorch Checkpoint。重构加强了**所有训练器都必须将完整的配置信息存入 `config` 字典**这一规则,确保了模型的完全可复现性。
### 5.5. 核心优势 (重构后)
- **逻辑集中**: 所有版本管理的复杂性都被封装在 `ModelManager` 内部,训练器只需调用 `save_model` 即可,无需关心版本号如何生成。
- **数据一致性**: 由于版本的生成、保存和检索都由同一个组件以同一种逻辑处理,从根本上杜绝了因命名或版本格式不匹配导致“模型未找到”的问题。
- **易于维护**: 未来如果需要修改版本策略或命名规则,只需修改 `ModelManager` 一个文件即可,无需改动所有训练器。
## 6. 核心流程的演进:支持店铺与全局模式
在最初的“按药品”流程基础上系统已重构以支持“按店铺”和“全局”的完整AI闭环。这引入了一些关键的逻辑变化
### 6.1. 训练流程的变化
- **统一入口**: 所有训练请求(药品、店铺、全局)都通过 `POST /api/training` 接口,由 `training_mode` 参数区分。
- **数据聚合**: 在 [`predictor.py`](server/core/predictor.py:1) 的 `train_model` 方法中,会根据 `training_mode` 调用 `aggregate_multi_store_data` 函数,为店铺或全局模式准备正确的聚合时间序列数据。
- **模型标识符**: `train_model` 方法现在会生成一个唯一的 `model_identifier`(例如 `product_17002608`, `store_01010023`, `global_sum`),并将其传递给所有下游训练器。这是确保模型被正确命名的关键。
### 6.2. 预测流程的重大修复
预测流程经过了重大修复,以解决之前因逻辑不统一导致的 `404` 错误。
- **废弃旧函数**: `core/config.py` 中的 `get_model_file_path` 和 `get_model_versions` 等旧的、有缺陷的辅助函数已被**完全废弃**。
- **统一查找逻辑**: 现在,[`api.py`](server/api.py:1) 的 `predict` 函数**必须**使用 `model_manager.list_models()` 方法来查找模型。
- **可靠的路径传递**: `predict` 函数找到正确的模型文件路径后,会将其作为一个参数,一路传递给 `run_prediction` 和最终的 `load_model_and_predict` 函数。
- **根除缺陷**: `load_model_and_predict` 函数内部所有手动的、过时的文件查找逻辑已被**完全移除**。它现在只负责接收一个明确的路径并加载模型。
这个修复确保了整个预测链路都依赖于 `ModelManager` 这一个“单一事实来源”,从根本上解决了因路径不匹配导致的预测失败问题。