系统开发设计指南

This commit is contained in:
LYFxiaoan 2025-07-17 15:52:04 +08:00
parent ee9ba299fa
commit e437658b9d
4 changed files with 581 additions and 1 deletions

View File

@ -165,7 +165,7 @@
### 16:16 - 项目状态更新
- **状态**: **所有已知问题已修复**
- **确认**: 用户已确认“现在药品和店铺预测流程通了
- **确认**: 用户已确认“现在药品和店铺预测流程通了。
- **后续**: 将本次修复过程归档至本文档。

Binary file not shown.

View File

@ -0,0 +1,464 @@
# 系统调用逻辑与核心代码分析
本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。
## 1. 系统启动
系统由两部分组成Vue.js前端和Flask后端。
### 1.1. 启动后端API服务
在项目根目录下,通过以下命令启动后端服务:
```bash
python server/api.py
```
该命令会启动一个Flask应用监听在 `http://localhost:5000`并提供所有API和WebSocket服务。
### 1.2. 启动前端开发服务器
进入 `UI` 目录,执行以下命令:
```bash
cd UI
npm install
npm run dev
```
这将启动Vite开发服务器通常在 `http://localhost:5173`,并自动打开浏览器访问前端页面。
## 2. 核心调用链路概览
以最核心的 **“按药品训练 -> 按药品预测”** 流程为例,其高层调用链路如下:
**训练流程:**
`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth`
**预测流程:**
`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染`
## 3. 详细流程:按药品训练
此流程的目标是为特定药品训练一个专用的预测模型。
### 3.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型如Transformer、设置训练轮次Epochs然后点击 **“启动药品训练”** 按钮。
2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。
3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。
*核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))*
```javascript
const startTraining = async () => {
// ... 表单验证 ...
trainingLoading.value = true;
try {
const endpoint = "/api/training";
const payload = {
product_id: form.product_id,
store_id: form.data_scope === 'global' ? null : form.store_id,
model_type: form.model_type,
epochs: form.epochs,
training_mode: 'product' // 标识这是药品训练模式
};
const response = await axios.post(endpoint, payload);
// ... 处理响应启动WebSocket监听 ...
}
// ... 错误处理 ...
};
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。
### 3.2. 后端API接收与任务分发
1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。
2. **任务提交**: `start_training()` 函数解析请求中的JSON数据然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行以避免阻塞API主线程。这使得API可以立即返回一个任务ID而训练在后台异步进行。
*核心代码 ([`server/api.py`](server/api.py:971))*
```python
@app.route('/api/training', methods=['POST'])
def start_training():
data = request.get_json()
training_mode = data.get('training_mode', 'product')
model_type = data.get('model_type')
epochs = data.get('epochs', 50)
product_id = data.get('product_id')
store_id = data.get('store_id')
if not model_type or (training_mode == 'product' and not product_id):
return jsonify({'error': '缺少必要参数'}), 400
try:
# 使用训练进程管理器提交任务
task_id = training_manager.submit_task(
product_id=product_id or "unknown",
model_type=model_type,
training_mode=training_mode,
store_id=store_id,
epochs=epochs
)
logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
return jsonify({
'message': '模型训练已开始(使用独立进程)',
'task_id': task_id,
})
except Exception as e:
logger.error(f"❌ 提交训练任务失败: {str(e)}")
return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
```
### 3.3. 核心训练逻辑
1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。
2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。
3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type``transformer`,它会调用 `train_product_model_with_transformer`
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))*
```python
class PharmacyPredictor:
def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
# ...
if training_mode == 'product':
product_data = self.data[self.data['product_id'] == product_id].copy()
# ...
# 根据训练模式构建模型标识符
model_identifier = product_id
try:
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
model_identifier=model_identifier,
product_df=product_data,
# ... 其他参数 ...
)
# ... 其他模型的elif分支 ...
return metrics
except Exception as e:
# ... 错误处理 ...
return None
```
### 3.4. 模型训练与保存
1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤:
* **数据预处理**: 调用 `prepare_data``prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
* **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。
* **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。
* **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。
* **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典checkpoint并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier``model_type``version` 统一生成。
*核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))*
```python
def train_product_model_with_transformer(...):
# ... 数据准备 ...
# 定义模型配置
config = {
'input_dim': input_dim,
'output_dim': forecast_horizon,
'hidden_size': hidden_size,
# ... 所有必要的超参数 ...
'model_type': 'transformer'
}
model = TimeSeriesTransformer(...)
# ... 训练循环 ...
# 保存模型
checkpoint = {
'model_state_dict': model.state_dict(),
'config': config,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'metrics': test_metrics
}
model_path = get_model_file_path(model_identifier, 'transformer', version)
torch.save(checkpoint, model_path)
return model, test_metrics, version
```
## 4. 详细流程:按药品预测
训练完成后,用户可以使用已保存的模型进行预测。
### 4.1. 前端交互与API请求
1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。
2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。
3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))*
```javascript
const startPrediction = async () => {
try {
predicting.value = true
const payload = {
product_id: form.product_id,
model_type: form.model_type,
version: form.version,
future_days: form.future_days,
// training_mode is implicitly 'product' here
}
const response = await axios.post('/api/prediction', payload)
if (response.data.status === 'success') {
predictionResult.value = response.data
await nextTick()
renderChart()
}
// ... 错误处理 ...
}
// ...
}
```
4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。
### 4.2. 后端API接收与预测执行
1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。
2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。
*核心代码 ([`server/api.py`](server/api.py:1469))*
```python
@app.route('/api/prediction', methods=['POST'])
def predict():
try:
data = request.json
# ... 解析参数 ...
training_mode = data.get('training_mode', 'product')
product_id = data.get('product_id')
# ...
# 根据模式确定模型标识符
if training_mode == 'product':
model_identifier = product_id
# ...
# 执行预测
prediction_result = run_prediction(model_type, product_id, model_id, ...)
# ... 格式化响应 ...
return jsonify(response_data)
except Exception as e:
# ... 错误处理 ...
```
3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode``product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。
*核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))*
```python
class PharmacyPredictor:
def predict(self, product_id, model_type, ..., training_mode='product', ...):
if training_mode == 'product':
model_identifier = product_id
# ...
return load_model_and_predict(
model_identifier,
model_type,
# ... 其他参数 ...
)
```
### 4.3. 模型加载与执行预测
[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤:
1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。
2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。
3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。**
4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。
5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。
6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。
7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。
8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。
9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame并连同历史数据一起返回。
*核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))*
```python
def load_model_and_predict(...):
# ... 找到模型文件路径 model_path ...
# 加载模型和配置
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例 (以Transformer为例)
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
# ... 使用config中的所有参数 ...
).to(DEVICE)
# 加载模型参数
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# ... 准备输入数据 ...
# 归一化输入数据
X_scaled = scaler_X.transform(X)
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
# 预测
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
# 反归一化
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
# ... 构建返回结果 ...
return {
'predictions': predictions_df,
'history_data': recent_history,
# ...
}
```
### 4.4. 响应格式化与前端图表渲染
1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构该结构在顶层同时包含 `history_data``prediction_data` 两个数组。
2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应并将其存入 `predictionResult` ref。
3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data``prediction_data`然后使用Chart.js库将这两部分数据绘制在同一个 `<canvas>` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。
*核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))*
```javascript
const renderChart = () => {
if (!chartCanvas.value || !predictionResult.value) return
// ...
// 后端直接提供 history_data 和 prediction_data
const historyData = predictionResult.value.history_data || []
const predictionData = predictionResult.value.prediction_data || []
const historyLabels = historyData.map(p => p.date)
const historySales = historyData.map(p => p.sales)
const predictionLabels = predictionData.map(p => p.date)
const predictionSales = predictionData.map(p => p.predicted_sales)
// ... 组合标签和数据,对齐数据点 ...
chart = new Chart(chartCanvas.value, {
type: 'line',
data: {
labels: allLabels,
datasets: [
{
label: '历史销量',
data: alignedHistorySales,
// ... 样式 ...
},
{
label: '预测销量',
data: alignedPredictionSales,
// ... 样式 ...
}
]
},
// ... Chart.js 配置 ...
})
}
```
至此,一个完整的“训练->预测->展示”的调用链路就完成了。
## 5. 模型保存规则与路径
为了确保模型的唯一性、可追溯性和可复现性,系统采用了一套严格的文件保存和命名规则。所有相关的逻辑都集中在 [`server/core/config.py`](server/core/config.py:1) 中。
### 5.1. 统一保存目录
所有训练产物包括模型权重、配置和数据缩放器Scalers都保存在项目根目录下的 `saved_models/` 文件夹中。
- **路径**: `PROJECT_ROOT/saved_models/`
- **定义**: 该路径由 [`server/core/config.py`](server/core/config.py:1) 中的 `DEFAULT_MODEL_DIR` 变量指定。
### 5.2. 文件命名规范
模型文件的命名遵循一个标准化的格式,以便在预测时能够被精确地定位和加载。该命名逻辑由 [`get_model_file_path()`](server/core/config.py:136) 函数统一管理。
**命名格式**: `{model_type}_{model_identifier}_epoch_{version}.pth`
**各部分说明**:
- `{model_type}`: 模型的算法类型。例如:`transformer`, `mlstm`, `tcn`, `kan`
- `{model_identifier}`: 模型的唯一业务标识符,它根据训练模式(`training_mode`)动态生成:
- **按药品训练 (`product`)**: 标识符就是 `product_id`
- *示例*: `transformer_17002608_epoch_best.pth`
- **按店铺训练 (`store`)**: 标识符是 `store_{store_id}`
- *示例*: `tcn_store_01010023_epoch_best.pth`
- **全局训练 (`global`)**: 标识符是固定的字符串 `'global'`
- *示例*: `mlstm_global_epoch_best.pth`
- `{version}`: 模型的版本。在训练过程中,通常会保存两个版本:
- `best`: 在验证集上表现最佳的模型。
- `{epoch_number}`: 训练完成时的最终模型,例如 `50`
前端的“版本”下拉框中显示的就是这些版本字符串。
### 5.3. Checkpoint文件内容
每个 `.pth` 文件都是一个PyTorch Checkpoint它是一个Python字典包含了重建和使用模型所需的所有信息。这是确保预测与训练环境一致的关键。
**Checkpoint结构**:
```python
checkpoint = {
# 1. 模型权重
'model_state_dict': model.state_dict(),
# 2. 完整的模型配置
'config': {
'input_dim': ...,
'hidden_size': ...,
'num_layers': ...,
'model_type': 'transformer',
# ... 其他所有重建模型所需的超参数 ...
},
# 3. 数据归一化缩放器
'scaler_X': scaler_X, # 用于输入特征
'scaler_y': scaler_y, # 用于目标值(销量)
# 4. (可选) 模型性能指标
'metrics': {'mse': 0.01, 'mae': 0.05, ...}
}
```
**核心优势**:
- **可复现性**: 通过保存完整的 `config`我们可以在预测时精确地重建出与训练时结构完全相同的模型实例避免了因模型结构不匹配导致的加载失败这是之前修复的一个核心BUG
- **数据一致性**: 保存 `scaler_X``scaler_y` 确保了在预测时使用与训练时完全相同的归一化/反归一化逻辑,保证了预测结果的正确性。

116
项目快速上手指南.md Normal file
View File

@ -0,0 +1,116 @@
# 项目快速上手指南 (面向新开发者)
欢迎加入项目本指南旨在帮助你快速理解项目的核心功能、技术架构和开发流程特别是为你一位Java背景的开发者提供清晰的切入点。
## 1. 项目是做什么的?(实现了什么功能)
这是一个基于历史销售数据的 **智能销售预测系统**
核心功能有三个全部通过Web界面操作
1. **模型训练**: 用户可以选择某个**药品**、某个**店铺**或**全局**数据然后选择一种机器学习算法如Transformer、mLSTM等进行训练最终生成一个预测模型。
2. **销售预测**: 使用已经训练好的模型,对未来的销量进行预测。
3. **结果可视化**: 将历史销量和预测销量在同一个图表中展示出来,方便用户直观地看到趋势。
简单来说,它就是一个 **"数据 -> 训练 -> 模型 -> 预测 -> 可视化"** 的完整闭环应用。
## 2. 用了什么技术?(技术栈)
你可以将这个项目的技术栈与Java世界进行类比
| 层面 | 本项目技术 | Java世界类比 | 说明 |
| :--- | :--- | :--- | :--- |
| **后端框架** | **Flask** | Spring Boot | 一个轻量级的Web框架用于提供API接口。 |
| **前端框架** | **Vue.js** | React / Angular | 用于构建用户交互界面的现代化JavaScript框架。 |
| **核心算法库** | **PyTorch** | (无直接对应) | 类似于Java的Deeplearning4j是实现深度学习算法的核心。 |
| **数据处理** | **Pandas** | (无直接对应) | Python中用于数据分析和处理的“瑞士军刀”可以看作是内存中的强大数据表格。 |
| **构建/打包** | **Vite** (前端) | Maven / Gradle | 前端项目的构建和依赖管理工具。 |
| **数据库** | **SQLite** | H2 / MySQL | 一个轻量级的本地文件数据库,用于记录预测历史等。 |
| **实时通信** | **Socket.IO** | WebSocket / STOMP | 用于后端在训练时向前端实时推送进度。 |
## 3. 系统架构是怎样的?(架构层级和设计)
本项目是经典的前后端分离架构,可以分为四个主要层次:
```
+------------------------------------------------------+
| 用户 (Browser) |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 1. 前端层 (Frontend - Vue.js) |
| - Views (页面组件, e.g., ProductPredictionView.vue) |
| - API Calls (使用axios与后端通信) |
| - Charting (使用Chart.js进行图表渲染) |
+------------------------------------------------------+
| (HTTP/S, WebSocket)
+------------------------------------------------------+
| 2. 后端API层 (Backend API - Flask) |
| - api.py (类似Controller, 定义RESTful接口) |
| - 接收请求, 验证参数, 调用业务逻辑层 |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 3. 业务逻辑层 (Business Logic - Python) |
| - core/predictor.py (类似Service层) |
| - 封装核心业务, 如“根据参数选择合适的训练器” |
+------------------------------------------------------+
|
+------------------------------------------------------+
| 4. 数据与模型层 (Data & Model - PyTorch/Pandas) |
| - trainers/*.py (具体的算法实现和训练逻辑) |
| - predictors/model_predictor.py (模型加载与预测逻辑) |
| - saved_models/ (存放训练好的.pth模型文件) |
| - data/ (存放原始数据.parquet文件) |
+------------------------------------------------------+
```
## 4. 关键执行流程
以最常见的“按药品预测”为例:
1. **前端**: 用户在页面上选择药品和模型点击“预测”按钮。Vue组件通过`axios`向后端发送一个POST请求到 `/api/prediction`
2. **API层**: `api.py` 接收到请求像一个Controller一样解析出药品ID、模型类型等参数。
3. **业务逻辑层**: `api.py` 调用 `core/predictor.py` 中的 `predict` 方法,将参数传递下去。这一层是业务的“调度中心”。
4. **模型层**: `core/predictor.py` 最终调用 `predictors/model_predictor.py` 中的 `load_model_and_predict` 函数。
5. **模型加载与执行**:
* 根据参数在 `saved_models/` 目录下找到对应的模型文件(例如 `transformer_17002608_epoch_best.pth`)。
* 加载文件,从中恢复出 **模型结构**、**模型权重** 和 **数据缩放器**
* 准备最新的历史数据作为输入,执行预测。
* 将预测结果返回。
6. **返回与渲染**: 结果逐层返回到`api.py`在这里被格式化为JSON然后发送给前端。前端接收到JSON后使用`Chart.js`将历史和预测数据画在图表上。
## 5. 如何添加一个新的算法?(开发者指南)
这是你最可能接触到的新功能开发。假设你要添加一个名为 `NewNet` 的新算法,你需要按以下步骤操作:
**目标**: 让 `NewNet` 出现在前端的“模型类型”下拉框中,并能成功训练和预测。
1. **创建训练器文件**:
* 在 `server/trainers/` 目录下,复制一份现有的训练器文件(例如 `tcn_trainer.py`)并重命名为 `newnet_trainer.py`
* 在 `newnet_trainer.py` 中:
* 定义你的 `NewNet` 模型类(继承自 `torch.nn.Module`)。
* 修改 `train_..._with_tcn` 函数,将其重命名为 `train_..._with_newnet`
* 在这个新函数里,确保实例化的是你的 `NewNet` 模型。
* **最关键的一步**: 在保存checkpoint时确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。
2. **注册新模型**:
* 打开 `server/core/config.py` 文件。
* 找到 `SUPPORTED_MODELS` 列表。
* 在列表中添加你的新模型标识符 `'newnet'`
3. **接入业务逻辑层 (训练)**:
* 打开 `server/core/predictor.py` 文件。
* 在 `train_model` 方法中,找到 `if/elif` 模型选择逻辑。
* 添加一个新的 `elif model_type == 'newnet':` 分支,让它调用你在第一步中创建的 `train_..._with_newnet` 函数。
4. **接入模型层 (预测)**:
* 打开 `server/predictors/model_predictor.py` 文件。
* 在 `load_model_and_predict` 函数中,找到 `if/elif` 模型实例化逻辑。
* 添加一个新的 `elif model_type == 'newnet':` 分支,确保它能根据 `config` 正确地创建 `NewNet` 模型实例。
5. **更新前端界面**:
* 打开 `UI/src/views/training/``UI/src/views/prediction/` 目录下的相关Vue文件`ProductTrainingView.vue`)。
* 找到定义模型选项的地方(通常是一个数组或对象)。
* 添加 `{ label: '新网络模型 (NewNet)', value: 'newnet' }` 这样的新选项。
完成以上步骤后,重启服务,你就可以在界面上选择并使用你的新算法了。