系统开发设计指南
This commit is contained in:
parent
ee9ba299fa
commit
e437658b9d
@ -165,7 +165,7 @@
|
||||
|
||||
### 16:16 - 项目状态更新
|
||||
- **状态**: **所有已知问题已修复**。
|
||||
- **确认**: 用户已确认“现在药品和店铺预测流程通了”。
|
||||
- **确认**: 用户已确认“现在药品和店铺预测流程通了。
|
||||
- **后续**: 将本次修复过程归档至本文档。
|
||||
|
||||
|
||||
|
Binary file not shown.
464
系统调用逻辑与核心代码分析.md
Normal file
464
系统调用逻辑与核心代码分析.md
Normal file
@ -0,0 +1,464 @@
|
||||
# 系统调用逻辑与核心代码分析
|
||||
|
||||
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
|
||||
|
||||
## 1. 系统启动
|
||||
|
||||
系统由两部分组成:Vue.js前端和Flask后端。
|
||||
|
||||
### 1.1. 启动后端API服务
|
||||
|
||||
在项目根目录下,通过以下命令启动后端服务:
|
||||
|
||||
```bash
|
||||
python server/api.py
|
||||
```
|
||||
|
||||
该命令会启动一个Flask应用,监听在 `http://localhost:5000`,并提供所有API和WebSocket服务。
|
||||
|
||||
### 1.2. 启动前端开发服务器
|
||||
|
||||
进入 `UI` 目录,执行以下命令:
|
||||
|
||||
```bash
|
||||
cd UI
|
||||
npm install
|
||||
npm run dev
|
||||
```
|
||||
|
||||
这将启动Vite开发服务器,通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
|
||||
|
||||
## 2. 核心调用链路概览
|
||||
|
||||
以最核心的 **“按药品训练 -> 按药品预测”** 流程为例,其高层调用链路如下:
|
||||
|
||||
**训练流程:**
|
||||
`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth`
|
||||
|
||||
**预测流程:**
|
||||
`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染`
|
||||
|
||||
## 3. 详细流程:按药品训练
|
||||
|
||||
此流程的目标是为特定药品训练一个专用的预测模型。
|
||||
|
||||
### 3.1. 前端交互与API请求
|
||||
|
||||
1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型(如Transformer)、设置训练轮次(Epochs),然后点击 **“启动药品训练”** 按钮。
|
||||
|
||||
2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。
|
||||
|
||||
3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。
|
||||
|
||||
*核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
|
||||
```javascript
|
||||
const startTraining = async () => {
|
||||
// ... 表单验证 ...
|
||||
trainingLoading.value = true;
|
||||
try {
|
||||
const endpoint = "/api/training";
|
||||
|
||||
const payload = {
|
||||
product_id: form.product_id,
|
||||
store_id: form.data_scope === 'global' ? null : form.store_id,
|
||||
model_type: form.model_type,
|
||||
epochs: form.epochs,
|
||||
training_mode: 'product' // 标识这是药品训练模式
|
||||
};
|
||||
|
||||
const response = await axios.post(endpoint, payload);
|
||||
// ... 处理响应,启动WebSocket监听 ...
|
||||
}
|
||||
// ... 错误处理 ...
|
||||
};
|
||||
```
|
||||
|
||||
4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。
|
||||
|
||||
### 3.2. 后端API接收与任务分发
|
||||
|
||||
1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。
|
||||
|
||||
2. **任务提交**: `start_training()` 函数解析请求中的JSON数据,然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行,以避免阻塞API主线程。这使得API可以立即返回一个任务ID,而训练在后台异步进行。
|
||||
|
||||
*核心代码 ([`server/api.py`](server/api.py:971))*
|
||||
```python
|
||||
@app.route('/api/training', methods=['POST'])
|
||||
def start_training():
|
||||
data = request.get_json()
|
||||
|
||||
training_mode = data.get('training_mode', 'product')
|
||||
model_type = data.get('model_type')
|
||||
epochs = data.get('epochs', 50)
|
||||
product_id = data.get('product_id')
|
||||
store_id = data.get('store_id')
|
||||
|
||||
if not model_type or (training_mode == 'product' and not product_id):
|
||||
return jsonify({'error': '缺少必要参数'}), 400
|
||||
|
||||
try:
|
||||
# 使用训练进程管理器提交任务
|
||||
task_id = training_manager.submit_task(
|
||||
product_id=product_id or "unknown",
|
||||
model_type=model_type,
|
||||
training_mode=training_mode,
|
||||
store_id=store_id,
|
||||
epochs=epochs
|
||||
)
|
||||
|
||||
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
|
||||
|
||||
return jsonify({
|
||||
'message': '模型训练已开始(使用独立进程)',
|
||||
'task_id': task_id,
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ 提交训练任务失败: {str(e)}")
|
||||
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
|
||||
```
|
||||
|
||||
### 3.3. 核心训练逻辑
|
||||
|
||||
1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。
|
||||
|
||||
2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。
|
||||
|
||||
3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type` 是 `transformer`,它会调用 `train_product_model_with_transformer`。
|
||||
|
||||
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))*
|
||||
```python
|
||||
class PharmacyPredictor:
|
||||
def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
|
||||
# ...
|
||||
if training_mode == 'product':
|
||||
product_data = self.data[self.data['product_id'] == product_id].copy()
|
||||
# ...
|
||||
|
||||
# 根据训练模式构建模型标识符
|
||||
model_identifier = product_id
|
||||
|
||||
try:
|
||||
if model_type == 'transformer':
|
||||
model_result, metrics, actual_version = train_product_model_with_transformer(
|
||||
product_id=product_id,
|
||||
model_identifier=model_identifier,
|
||||
product_df=product_data,
|
||||
# ... 其他参数 ...
|
||||
)
|
||||
# ... 其他模型的elif分支 ...
|
||||
|
||||
return metrics
|
||||
except Exception as e:
|
||||
# ... 错误处理 ...
|
||||
return None
|
||||
```
|
||||
|
||||
### 3.4. 模型训练与保存
|
||||
|
||||
1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤:
|
||||
* **数据预处理**: 调用 `prepare_data` 和 `prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
|
||||
* **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。
|
||||
* **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。
|
||||
* **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。
|
||||
* **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典(checkpoint),并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier`、`model_type` 和 `version` 统一生成。
|
||||
|
||||
*核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
|
||||
```python
|
||||
def train_product_model_with_transformer(...):
|
||||
# ... 数据准备 ...
|
||||
|
||||
# 定义模型配置
|
||||
config = {
|
||||
'input_dim': input_dim,
|
||||
'output_dim': forecast_horizon,
|
||||
'hidden_size': hidden_size,
|
||||
# ... 所有必要的超参数 ...
|
||||
'model_type': 'transformer'
|
||||
}
|
||||
|
||||
model = TimeSeriesTransformer(...)
|
||||
|
||||
# ... 训练循环 ...
|
||||
|
||||
# 保存模型
|
||||
checkpoint = {
|
||||
'model_state_dict': model.state_dict(),
|
||||
'config': config,
|
||||
'scaler_X': scaler_X,
|
||||
'scaler_y': scaler_y,
|
||||
'metrics': test_metrics
|
||||
}
|
||||
|
||||
model_path = get_model_file_path(model_identifier, 'transformer', version)
|
||||
torch.save(checkpoint, model_path)
|
||||
|
||||
return model, test_metrics, version
|
||||
```
|
||||
|
||||
## 4. 详细流程:按药品预测
|
||||
|
||||
训练完成后,用户可以使用已保存的模型进行预测。
|
||||
|
||||
### 4.1. 前端交互与API请求
|
||||
|
||||
1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。
|
||||
|
||||
2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。
|
||||
|
||||
3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`。
|
||||
|
||||
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
|
||||
```javascript
|
||||
const startPrediction = async () => {
|
||||
try {
|
||||
predicting.value = true
|
||||
const payload = {
|
||||
product_id: form.product_id,
|
||||
model_type: form.model_type,
|
||||
version: form.version,
|
||||
future_days: form.future_days,
|
||||
// training_mode is implicitly 'product' here
|
||||
}
|
||||
const response = await axios.post('/api/prediction', payload)
|
||||
if (response.data.status === 'success') {
|
||||
predictionResult.value = response.data
|
||||
await nextTick()
|
||||
renderChart()
|
||||
}
|
||||
// ... 错误处理 ...
|
||||
}
|
||||
// ...
|
||||
}
|
||||
```
|
||||
|
||||
4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。
|
||||
|
||||
### 4.2. 后端API接收与预测执行
|
||||
|
||||
1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。
|
||||
|
||||
2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。
|
||||
|
||||
*核心代码 ([`server/api.py`](server/api.py:1469))*
|
||||
```python
|
||||
@app.route('/api/prediction', methods=['POST'])
|
||||
def predict():
|
||||
try:
|
||||
data = request.json
|
||||
# ... 解析参数 ...
|
||||
training_mode = data.get('training_mode', 'product')
|
||||
product_id = data.get('product_id')
|
||||
# ...
|
||||
|
||||
# 根据模式确定模型标识符
|
||||
if training_mode == 'product':
|
||||
model_identifier = product_id
|
||||
# ...
|
||||
|
||||
# 执行预测
|
||||
prediction_result = run_prediction(model_type, product_id, model_id, ...)
|
||||
|
||||
# ... 格式化响应 ...
|
||||
return jsonify(response_data)
|
||||
except Exception as e:
|
||||
# ... 错误处理 ...
|
||||
```
|
||||
|
||||
3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode` 和 `product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。
|
||||
|
||||
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))*
|
||||
```python
|
||||
class PharmacyPredictor:
|
||||
def predict(self, product_id, model_type, ..., training_mode='product', ...):
|
||||
if training_mode == 'product':
|
||||
model_identifier = product_id
|
||||
# ...
|
||||
|
||||
return load_model_and_predict(
|
||||
model_identifier,
|
||||
model_type,
|
||||
# ... 其他参数 ...
|
||||
)
|
||||
```
|
||||
|
||||
### 4.3. 模型加载与执行预测
|
||||
|
||||
[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤:
|
||||
|
||||
1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。
|
||||
|
||||
2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。
|
||||
|
||||
3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。**
|
||||
|
||||
4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。
|
||||
|
||||
5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。
|
||||
|
||||
6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。
|
||||
|
||||
7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。
|
||||
|
||||
8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。
|
||||
|
||||
9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame,并连同历史数据一起返回。
|
||||
|
||||
*核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
|
||||
```python
|
||||
def load_model_and_predict(...):
|
||||
# ... 找到模型文件路径 model_path ...
|
||||
|
||||
# 加载模型和配置
|
||||
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
|
||||
config = checkpoint['config']
|
||||
scaler_X = checkpoint['scaler_X']
|
||||
scaler_y = checkpoint['scaler_y']
|
||||
|
||||
# 创建模型实例 (以Transformer为例)
|
||||
model = TimeSeriesTransformer(
|
||||
num_features=config['input_dim'],
|
||||
d_model=config['hidden_size'],
|
||||
# ... 使用config中的所有参数 ...
|
||||
).to(DEVICE)
|
||||
|
||||
# 加载模型参数
|
||||
model.load_state_dict(checkpoint['model_state_dict'])
|
||||
model.eval()
|
||||
|
||||
# ... 准备输入数据 ...
|
||||
|
||||
# 归一化输入数据
|
||||
X_scaled = scaler_X.transform(X)
|
||||
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
|
||||
|
||||
# 预测
|
||||
with torch.no_grad():
|
||||
y_pred_scaled = model(X_input).cpu().numpy()
|
||||
|
||||
# 反归一化
|
||||
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
|
||||
|
||||
# ... 构建返回结果 ...
|
||||
return {
|
||||
'predictions': predictions_df,
|
||||
'history_data': recent_history,
|
||||
# ...
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4. 响应格式化与前端图表渲染
|
||||
|
||||
1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构,该结构在顶层同时包含 `history_data` 和 `prediction_data` 两个数组。
|
||||
|
||||
2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应,并将其存入 `predictionResult` ref。
|
||||
|
||||
3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data` 和 `prediction_data`,然后使用Chart.js库将这两部分数据绘制在同一个 `<canvas>` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。
|
||||
|
||||
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
|
||||
```javascript
|
||||
const renderChart = () => {
|
||||
if (!chartCanvas.value || !predictionResult.value) return
|
||||
// ...
|
||||
|
||||
// 后端直接提供 history_data 和 prediction_data
|
||||
const historyData = predictionResult.value.history_data || []
|
||||
const predictionData = predictionResult.value.prediction_data || []
|
||||
|
||||
const historyLabels = historyData.map(p => p.date)
|
||||
const historySales = historyData.map(p => p.sales)
|
||||
|
||||
const predictionLabels = predictionData.map(p => p.date)
|
||||
const predictionSales = predictionData.map(p => p.predicted_sales)
|
||||
|
||||
// ... 组合标签和数据,对齐数据点 ...
|
||||
|
||||
chart = new Chart(chartCanvas.value, {
|
||||
type: 'line',
|
||||
data: {
|
||||
labels: allLabels,
|
||||
datasets: [
|
||||
{
|
||||
label: '历史销量',
|
||||
data: alignedHistorySales,
|
||||
// ... 样式 ...
|
||||
},
|
||||
{
|
||||
label: '预测销量',
|
||||
data: alignedPredictionSales,
|
||||
// ... 样式 ...
|
||||
}
|
||||
]
|
||||
},
|
||||
// ... Chart.js 配置 ...
|
||||
})
|
||||
}
|
||||
```
|
||||
|
||||
至此,一个完整的“训练->预测->展示”的调用链路就完成了。
|
||||
|
||||
## 5. 模型保存规则与路径
|
||||
|
||||
为了确保模型的唯一性、可追溯性和可复现性,系统采用了一套严格的文件保存和命名规则。所有相关的逻辑都集中在 [`server/core/config.py`](server/core/config.py:1) 中。
|
||||
|
||||
### 5.1. 统一保存目录
|
||||
|
||||
所有训练产物,包括模型权重、配置和数据缩放器(Scalers),都保存在项目根目录下的 `saved_models/` 文件夹中。
|
||||
|
||||
- **路径**: `PROJECT_ROOT/saved_models/`
|
||||
- **定义**: 该路径由 [`server/core/config.py`](server/core/config.py:1) 中的 `DEFAULT_MODEL_DIR` 变量指定。
|
||||
|
||||
### 5.2. 文件命名规范
|
||||
|
||||
模型文件的命名遵循一个标准化的格式,以便在预测时能够被精确地定位和加载。该命名逻辑由 [`get_model_file_path()`](server/core/config.py:136) 函数统一管理。
|
||||
|
||||
**命名格式**: `{model_type}_{model_identifier}_epoch_{version}.pth`
|
||||
|
||||
**各部分说明**:
|
||||
|
||||
- `{model_type}`: 模型的算法类型。例如:`transformer`, `mlstm`, `tcn`, `kan`。
|
||||
- `{model_identifier}`: 模型的唯一业务标识符,它根据训练模式(`training_mode`)动态生成:
|
||||
- **按药品训练 (`product`)**: 标识符就是 `product_id`。
|
||||
- *示例*: `transformer_17002608_epoch_best.pth`
|
||||
- **按店铺训练 (`store`)**: 标识符是 `store_{store_id}`。
|
||||
- *示例*: `tcn_store_01010023_epoch_best.pth`
|
||||
- **全局训练 (`global`)**: 标识符是固定的字符串 `'global'`。
|
||||
- *示例*: `mlstm_global_epoch_best.pth`
|
||||
- `{version}`: 模型的版本。在训练过程中,通常会保存两个版本:
|
||||
- `best`: 在验证集上表现最佳的模型。
|
||||
- `{epoch_number}`: 训练完成时的最终模型,例如 `50`。
|
||||
前端的“版本”下拉框中显示的就是这些版本字符串。
|
||||
|
||||
### 5.3. Checkpoint文件内容
|
||||
|
||||
每个 `.pth` 文件都是一个PyTorch Checkpoint,它是一个Python字典,包含了重建和使用模型所需的所有信息。这是确保预测与训练环境一致的关键。
|
||||
|
||||
**Checkpoint结构**:
|
||||
|
||||
```python
|
||||
checkpoint = {
|
||||
# 1. 模型权重
|
||||
'model_state_dict': model.state_dict(),
|
||||
|
||||
# 2. 完整的模型配置
|
||||
'config': {
|
||||
'input_dim': ...,
|
||||
'hidden_size': ...,
|
||||
'num_layers': ...,
|
||||
'model_type': 'transformer',
|
||||
# ... 其他所有重建模型所需的超参数 ...
|
||||
},
|
||||
|
||||
# 3. 数据归一化缩放器
|
||||
'scaler_X': scaler_X, # 用于输入特征
|
||||
'scaler_y': scaler_y, # 用于目标值(销量)
|
||||
|
||||
# 4. (可选) 模型性能指标
|
||||
'metrics': {'mse': 0.01, 'mae': 0.05, ...}
|
||||
}
|
||||
```
|
||||
|
||||
**核心优势**:
|
||||
|
||||
- **可复现性**: 通过保存完整的 `config`,我们可以在预测时精确地重建出与训练时结构完全相同的模型实例,避免了因模型结构不匹配导致的加载失败(这是之前修复的一个核心BUG)。
|
||||
- **数据一致性**: 保存 `scaler_X` 和 `scaler_y` 确保了在预测时使用与训练时完全相同的归一化/反归一化逻辑,保证了预测结果的正确性。
|
116
项目快速上手指南.md
Normal file
116
项目快速上手指南.md
Normal file
@ -0,0 +1,116 @@
|
||||
# 项目快速上手指南 (面向新开发者)
|
||||
|
||||
欢迎加入项目!本指南旨在帮助你快速理解项目的核心功能、技术架构和开发流程,特别是为你(一位Java背景的开发者)提供清晰的切入点。
|
||||
|
||||
## 1. 项目是做什么的?(实现了什么功能)
|
||||
|
||||
这是一个基于历史销售数据的 **智能销售预测系统**。
|
||||
|
||||
核心功能有三个,全部通过Web界面操作:
|
||||
1. **模型训练**: 用户可以选择某个**药品**、某个**店铺**或**全局**数据,然后选择一种机器学习算法(如Transformer、mLSTM等)进行训练,最终生成一个预测模型。
|
||||
2. **销售预测**: 使用已经训练好的模型,对未来的销量进行预测。
|
||||
3. **结果可视化**: 将历史销量和预测销量在同一个图表中展示出来,方便用户直观地看到趋势。
|
||||
|
||||
简单来说,它就是一个 **"数据 -> 训练 -> 模型 -> 预测 -> 可视化"** 的完整闭环应用。
|
||||
|
||||
## 2. 用了什么技术?(技术栈)
|
||||
|
||||
你可以将这个项目的技术栈与Java世界进行类比:
|
||||
|
||||
| 层面 | 本项目技术 | Java世界类比 | 说明 |
|
||||
| :--- | :--- | :--- | :--- |
|
||||
| **后端框架** | **Flask** | Spring Boot | 一个轻量级的Web框架,用于提供API接口。 |
|
||||
| **前端框架** | **Vue.js** | React / Angular | 用于构建用户交互界面的现代化JavaScript框架。 |
|
||||
| **核心算法库** | **PyTorch** | (无直接对应) | 类似于Java的Deeplearning4j,是实现深度学习算法的核心。 |
|
||||
| **数据处理** | **Pandas** | (无直接对应) | Python中用于数据分析和处理的“瑞士军刀”,可以看作是内存中的强大数据表格。 |
|
||||
| **构建/打包** | **Vite** (前端) | Maven / Gradle | 前端项目的构建和依赖管理工具。 |
|
||||
| **数据库** | **SQLite** | H2 / MySQL | 一个轻量级的本地文件数据库,用于记录预测历史等。 |
|
||||
| **实时通信** | **Socket.IO** | WebSocket / STOMP | 用于后端在训练时向前端实时推送进度。 |
|
||||
|
||||
## 3. 系统架构是怎样的?(架构层级和设计)
|
||||
|
||||
本项目是经典的前后端分离架构,可以分为四个主要层次:
|
||||
|
||||
```
|
||||
+------------------------------------------------------+
|
||||
| 用户 (Browser) |
|
||||
+------------------------------------------------------+
|
||||
|
|
||||
+------------------------------------------------------+
|
||||
| 1. 前端层 (Frontend - Vue.js) |
|
||||
| - Views (页面组件, e.g., ProductPredictionView.vue) |
|
||||
| - API Calls (使用axios与后端通信) |
|
||||
| - Charting (使用Chart.js进行图表渲染) |
|
||||
+------------------------------------------------------+
|
||||
| (HTTP/S, WebSocket)
|
||||
+------------------------------------------------------+
|
||||
| 2. 后端API层 (Backend API - Flask) |
|
||||
| - api.py (类似Controller, 定义RESTful接口) |
|
||||
| - 接收请求, 验证参数, 调用业务逻辑层 |
|
||||
+------------------------------------------------------+
|
||||
|
|
||||
+------------------------------------------------------+
|
||||
| 3. 业务逻辑层 (Business Logic - Python) |
|
||||
| - core/predictor.py (类似Service层) |
|
||||
| - 封装核心业务, 如“根据参数选择合适的训练器” |
|
||||
+------------------------------------------------------+
|
||||
|
|
||||
+------------------------------------------------------+
|
||||
| 4. 数据与模型层 (Data & Model - PyTorch/Pandas) |
|
||||
| - trainers/*.py (具体的算法实现和训练逻辑) |
|
||||
| - predictors/model_predictor.py (模型加载与预测逻辑) |
|
||||
| - saved_models/ (存放训练好的.pth模型文件) |
|
||||
| - data/ (存放原始数据.parquet文件) |
|
||||
+------------------------------------------------------+
|
||||
```
|
||||
|
||||
## 4. 关键执行流程
|
||||
|
||||
以最常见的“按药品预测”为例:
|
||||
|
||||
1. **前端**: 用户在页面上选择药品和模型,点击“预测”按钮。Vue组件通过`axios`向后端发送一个POST请求到 `/api/prediction`。
|
||||
2. **API层**: `api.py` 接收到请求,像一个Controller一样,解析出药品ID、模型类型等参数。
|
||||
3. **业务逻辑层**: `api.py` 调用 `core/predictor.py` 中的 `predict` 方法,将参数传递下去。这一层是业务的“调度中心”。
|
||||
4. **模型层**: `core/predictor.py` 最终调用 `predictors/model_predictor.py` 中的 `load_model_and_predict` 函数。
|
||||
5. **模型加载与执行**:
|
||||
* 根据参数在 `saved_models/` 目录下找到对应的模型文件(例如 `transformer_17002608_epoch_best.pth`)。
|
||||
* 加载文件,从中恢复出 **模型结构**、**模型权重** 和 **数据缩放器**。
|
||||
* 准备最新的历史数据作为输入,执行预测。
|
||||
* 将预测结果返回。
|
||||
6. **返回与渲染**: 结果逐层返回到`api.py`,在这里被格式化为JSON,然后发送给前端。前端接收到JSON后,使用`Chart.js`将历史和预测数据画在图表上。
|
||||
|
||||
## 5. 如何添加一个新的算法?(开发者指南)
|
||||
|
||||
这是你最可能接触到的新功能开发。假设你要添加一个名为 `NewNet` 的新算法,你需要按以下步骤操作:
|
||||
|
||||
**目标**: 让 `NewNet` 出现在前端的“模型类型”下拉框中,并能成功训练和预测。
|
||||
|
||||
1. **创建训练器文件**:
|
||||
* 在 `server/trainers/` 目录下,复制一份现有的训练器文件(例如 `tcn_trainer.py`)并重命名为 `newnet_trainer.py`。
|
||||
* 在 `newnet_trainer.py` 中:
|
||||
* 定义你的 `NewNet` 模型类(继承自 `torch.nn.Module`)。
|
||||
* 修改 `train_..._with_tcn` 函数,将其重命名为 `train_..._with_newnet`。
|
||||
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
|
||||
* **最关键的一步**: 在保存checkpoint时,确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
|
||||
|
||||
2. **注册新模型**:
|
||||
* 打开 `server/core/config.py` 文件。
|
||||
* 找到 `SUPPORTED_MODELS` 列表。
|
||||
* 在列表中添加你的新模型标识符 `'newnet'`。
|
||||
|
||||
3. **接入业务逻辑层 (训练)**:
|
||||
* 打开 `server/core/predictor.py` 文件。
|
||||
* 在 `train_model` 方法中,找到 `if/elif` 模型选择逻辑。
|
||||
* 添加一个新的 `elif model_type == 'newnet':` 分支,让它调用你在第一步中创建的 `train_..._with_newnet` 函数。
|
||||
|
||||
4. **接入模型层 (预测)**:
|
||||
* 打开 `server/predictors/model_predictor.py` 文件。
|
||||
* 在 `load_model_and_predict` 函数中,找到 `if/elif` 模型实例化逻辑。
|
||||
* 添加一个新的 `elif model_type == 'newnet':` 分支,确保它能根据 `config` 正确地创建 `NewNet` 模型实例。
|
||||
|
||||
5. **更新前端界面**:
|
||||
* 打开 `UI/src/views/training/` 和 `UI/src/views/prediction/` 目录下的相关Vue文件(如 `ProductTrainingView.vue`)。
|
||||
* 找到定义模型选项的地方(通常是一个数组或对象)。
|
||||
* 添加 `{ label: '新网络模型 (NewNet)', value: 'newnet' }` 这样的新选项。
|
||||
|
||||
完成以上步骤后,重启服务,你就可以在界面上选择并使用你的新算法了。
|
Loading…
x
Reference in New Issue
Block a user