ShopTRAINING/系统调用逻辑与核心代码分析.md

464 lines
20 KiB
Markdown
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# 系统调用逻辑与核心代码分析
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
## 1. 系统启动
系统由两部分组成Vue.js前端和Flask后端。
### 1.1. 启动后端API服务
在项目根目录下,通过以下命令启动后端服务:
```bash
python server/api.py
```
该命令会启动一个Flask应用监听在 `http://localhost:5000`并提供所有API和WebSocket服务。
### 1.2. 启动前端开发服务器
进入 `UI` 目录,执行以下命令:
```bash
cd UI
npm install
npm run dev
```
这将启动Vite开发服务器通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
## 2. 核心调用链路概览
以最核心的 **“按药品训练 -> 按药品预测”** 流程为例,其高层调用链路如下:
**训练流程:**
`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth`
**预测流程:**
`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染`
## 3. 详细流程:按药品训练
此流程的目标是为特定药品训练一个专用的预测模型。
### 3.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型如Transformer、设置训练轮次Epochs然后点击 **“启动药品训练”** 按钮。
2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。
3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。
*核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
```javascript
const startTraining = async () => {
// ... 表单验证 ...
trainingLoading.value = true;
try {
const endpoint = "/api/training";
const payload = {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
epochs: form.epochs,
training_mode: 'product' // 标识这是药品训练模式
};
const response = await axios.post(endpoint, payload);
// ... 处理响应启动WebSocket监听 ...
}
// ... 错误处理 ...
};
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。
### 3.2. 后端API接收与任务分发
1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。
2. **任务提交**: `start_training()` 函数解析请求中的JSON数据然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行以避免阻塞API主线程。这使得API可以立即返回一个任务ID而训练在后台异步进行。
*核心代码 ([`server/api.py`](server/api.py:971))*
```python
@app.route('/api/training', methods=['POST'])
def start_training():
data = request.get_json()
training_mode = data.get('training_mode', 'product')
model_type = data.get('model_type')
epochs = data.get('epochs', 50)
product_id = data.get('product_id')
store_id = data.get('store_id')
if not model_type or (training_mode == 'product' and not product_id):
return jsonify({'error': '缺少必要参数'}), 400
try:
# 使用训练进程管理器提交任务
task_id = training_manager.submit_task(
product_id=product_id or "unknown",
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
```
### 3.3. 核心训练逻辑
1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。
2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。
3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type` 是 `transformer`,它会调用 `train_product_model_with_transformer`。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))*
```python
class PharmacyPredictor:
def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
# ...
if training_mode == 'product':
product_data = self.data[self.data['product_id'] == product_id].copy()
# ...
# 根据训练模式构建模型标识符
model_identifier = product_id
try:
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
# ... 其他参数 ...
)
# ... 其他模型的elif分支 ...
return metrics
except Exception as e:
# ... 错误处理 ...
return None
```
### 3.4. 模型训练与保存
1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤:
* **数据预处理**: 调用 `prepare_data` 和 `prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
* **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。
* **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。
* **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。
* **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典checkpoint并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier`、`model_type` 和 `version` 统一生成。
*核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
```python
def train_product_model_with_transformer(...):
# ... 数据准备 ...
# 定义模型配置
config = {
'input_dim': input_dim,
'output_dim': forecast_horizon,
'hidden_size': hidden_size,
# ... 所有必要的超参数 ...
'model_type': 'transformer'
}
model = TimeSeriesTransformer(...)
# ... 训练循环 ...
# 保存模型
checkpoint = {
'model_state_dict': model.state_dict(),
'config': config,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'metrics': test_metrics
}
model_path = get_model_file_path(model_identifier, 'transformer', version)
torch.save(checkpoint, model_path)
return model, test_metrics, version
```
## 4. 详细流程:按药品预测
训练完成后,用户可以使用已保存的模型进行预测。
### 4.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。
2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。
3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
```javascript
const startPrediction = async () => {
try {
predicting.value = true
const payload = {
product_id: form.product_id,
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
// training_mode is implicitly 'product' here
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data
await nextTick()
renderChart()
}
// ... 错误处理 ...
}
// ...
}
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。
### 4.2. 后端API接收与预测执行
1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。
2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。
*核心代码 ([`server/api.py`](server/api.py:1469))*
```python
@app.route('/api/prediction', methods=['POST'])
def predict():
try:
data = request.json
# ... 解析参数 ...
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
# ...
# 根据模式确定模型标识符
if training_mode == 'product':
model_identifier = product_id
# ...
# 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, ...)
# ... 格式化响应 ...
return jsonify(response_data)
except Exception as e:
# ... 错误处理 ...
```
3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode` 和 `product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))*
```python
class PharmacyPredictor:
def predict(self, product_id, model_type, ..., training_mode='product', ...):
if training_mode == 'product':
model_identifier = product_id
# ...
return load_model_and_predict(
model_identifier,
model_type,
# ... 其他参数 ...
)
```
### 4.3. 模型加载与执行预测
[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤:
1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。
2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。
3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。**
4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。
5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。
6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。
7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。
8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。
9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame并连同历史数据一起返回。
*核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
# ... 找到模型文件路径 model_path ...
# 加载模型和配置
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例 (以Transformer为例)
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
# ... 使用config中的所有参数 ...
).to(DEVICE)
# 加载模型参数
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# ... 准备输入数据 ...
# 归一化输入数据
X_scaled = scaler_X.transform(X)
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
# 预测
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
# 反归一化
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
# ... 构建返回结果 ...
return {
'predictions': predictions_df,
'history_data': recent_history,
# ...
}
```
### 4.4. 响应格式化与前端图表渲染
1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构该结构在顶层同时包含 `history_data` 和 `prediction_data` 两个数组。
2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应并将其存入 `predictionResult` ref。
3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data` 和 `prediction_data`然后使用Chart.js库将这两部分数据绘制在同一个 `<canvas>` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
```javascript
const renderChart = () => {
if (!chartCanvas.value || !predictionResult.value) return
// ...
// 后端直接提供 history_data 和 prediction_data
const historyData = predictionResult.value.history_data || []
const predictionData = predictionResult.value.prediction_data || []
const historyLabels = historyData.map(p => p.date)
const historySales = historyData.map(p => p.sales)
const predictionLabels = predictionData.map(p => p.date)
const predictionSales = predictionData.map(p => p.predicted_sales)
// ... 组合标签和数据,对齐数据点 ...
chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels: allLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
// ... 样式 ...
},
{
label: '预测销量',
data: alignedPredictionSales,
// ... 样式 ...
}
]
},
// ... Chart.js 配置 ...
})
}
```
至此,一个完整的“训练->预测->展示”的调用链路就完成了。
## 5. 模型保存规则与路径
为了确保模型的唯一性、可追溯性和可复现性,系统采用了一套严格的文件保存和命名规则。所有相关的逻辑都集中在 [`server/core/config.py`](server/core/config.py:1) 中。
### 5.1. 统一保存目录
所有训练产物包括模型权重、配置和数据缩放器Scalers都保存在项目根目录下的 `saved_models/` 文件夹中。
- **路径**: `PROJECT_ROOT/saved_models/`
- **定义**: 该路径由 [`server/core/config.py`](server/core/config.py:1) 中的 `DEFAULT_MODEL_DIR` 变量指定。
### 5.2. 文件命名规范
模型文件的命名遵循一个标准化的格式,以便在预测时能够被精确地定位和加载。该命名逻辑由 [`get_model_file_path()`](server/core/config.py:136) 函数统一管理。
**命名格式**: `{model_type}_{model_identifier}_epoch_{version}.pth`
**各部分说明**:
- `{model_type}`: 模型的算法类型。例如:`transformer`, `mlstm`, `tcn`, `kan`。
- `{model_identifier}`: 模型的唯一业务标识符,它根据训练模式(`training_mode`)动态生成:
- **按药品训练 (`product`)**: 标识符就是 `product_id`。
- *示例*: `transformer_17002608_epoch_best.pth`
- **按店铺训练 (`store`)**: 标识符是 `store_{store_id}`。
- *示例*: `tcn_store_01010023_epoch_best.pth`
- **全局训练 (`global`)**: 标识符是固定的字符串 `'global'`。
- *示例*: `mlstm_global_epoch_best.pth`
- `{version}`: 模型的版本。在训练过程中,通常会保存两个版本:
- `best`: 在验证集上表现最佳的模型。
- `{epoch_number}`: 训练完成时的最终模型,例如 `50`。
前端的“版本”下拉框中显示的就是这些版本字符串。
### 5.3. Checkpoint文件内容
每个 `.pth` 文件都是一个PyTorch Checkpoint它是一个Python字典包含了重建和使用模型所需的所有信息。这是确保预测与训练环境一致的关键。
**Checkpoint结构**:
```python
checkpoint = {
# 1. 模型权重
'model_state_dict': model.state_dict(),
# 2. 完整的模型配置
'config': {
'input_dim': ...,
'hidden_size': ...,
'num_layers': ...,
'model_type': 'transformer',
# ... 其他所有重建模型所需的超参数 ...
},
# 3. 数据归一化缩放器
'scaler_X': scaler_X, # 用于输入特征
'scaler_y': scaler_y, # 用于目标值(销量)
# 4. (可选) 模型性能指标
'metrics': {'mse': 0.01, 'mae': 0.05, ...}
}
```
**核心优势**:
- **可复现性**: 通过保存完整的 `config`我们可以在预测时精确地重建出与训练时结构完全相同的模型实例避免了因模型结构不匹配导致的加载失败这是之前修复的一个核心BUG
- **数据一致性**: 保存 `scaler_X` 和 `scaler_y` 确保了在预测时使用与训练时完全相同的归一化/反归一化逻辑,保证了预测结果的正确性。