From e437658b9dd73a09fdc0651e2ffd49c1edf2345a Mon Sep 17 00:00:00 2001 From: LYFxiaoan Date: Thu, 17 Jul 2025 15:52:04 +0800 Subject: [PATCH] =?UTF-8?q?=E7=B3=BB=E7=BB=9F=E5=BC=80=E5=8F=91=E8=AE=BE?= =?UTF-8?q?=E8=AE=A1=E6=8C=87=E5=8D=97?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- lyf开发日志记录文档.md | 2 +- prediction_history.db | Bin 303104 -> 307200 bytes 系统调用逻辑与核心代码分析.md | 464 ++++++++++++++++++++++++++++++++++ 项目快速上手指南.md | 116 +++++++++ 4 files changed, 581 insertions(+), 1 deletion(-) create mode 100644 系统调用逻辑与核心代码分析.md create mode 100644 项目快速上手指南.md diff --git a/lyf开发日志记录文档.md b/lyf开发日志记录文档.md index 8aae27f..a75b2f1 100644 --- a/lyf开发日志记录文档.md +++ b/lyf开发日志记录文档.md @@ -165,7 +165,7 @@ ### 16:16 - 项目状态更新 - **状态**: **所有已知问题已修复**。 -- **确认**: 用户已确认“现在药品和店铺预测流程通了”。 +- **确认**: 用户已确认“现在药品和店铺预测流程通了。 - **后续**: 将本次修复过程归档至本文档。 diff --git a/prediction_history.db b/prediction_history.db index 329cd45e1b37fc423b153a5af8241cb52501dd7e..16cbe924d9fa39ac21b463885e535af512a709a3 100644 GIT binary patch delta 657 zcmZoTAk^?cXo56rJOcxR_e2GI#`uj1OZYjg`0fD(7`!L*2oy2#-PixOlARW($ORb+JROcfLqZImX) z{^@!dEE#;3RtCmaMrL}(2F4afM$;8@S+q5bEsZSoOf3x!%uGy7jSY=Vbrh6}6LV6F zm8=wy<)$C3W|5k%AH^b}4iz>qHa0TAt^+JJ{bL@B 按药品预测”** 流程为例,其高层调用链路如下: + +**训练流程:** +`前端UI` -> `POST /api/training` -> `api.py: start_training()` -> `TrainingManager` -> `后台进程` -> `predictor.py: train_model()` -> `[model]_trainer.py: train_product_model_with_*()` -> `保存模型.pth` + +**预测流程:** +`前端UI` -> `POST /api/prediction` -> `api.py: predict()` -> `predictor.py: predict()` -> `model_predictor.py: load_model_and_predict()` -> `加载模型.pth` -> `返回预测JSON` -> `前端图表渲染` + +## 3. 详细流程:按药品训练 + +此流程的目标是为特定药品训练一个专用的预测模型。 + +### 3.1. 前端交互与API请求 + +1. **用户操作**: 用户在 **“按药品训练”** 页面 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:1)) 选择一个药品、一个模型类型(如Transformer)、设置训练轮次(Epochs),然后点击 **“启动药品训练”** 按钮。 + +2. **触发函数**: 点击事件调用 [`startTraining`](UI/src/views/training/ProductTrainingView.vue:521) 方法。 + +3. **构建Payload**: `startTraining` 方法构建一个包含训练参数的 `payload` 对象。关键字段是 `training_mode: 'product'`,用于告知后端这是针对特定产品的训练。 + + *核心代码 ([`UI/src/views/training/ProductTrainingView.vue`](UI/src/views/training/ProductTrainingView.vue:521))* + ```javascript + const startTraining = async () => { + // ... 表单验证 ... + trainingLoading.value = true; + try { + const endpoint = "/api/training"; + + const payload = { + product_id: form.product_id, + store_id: form.data_scope === 'global' ? null : form.store_id, + model_type: form.model_type, + epochs: form.epochs, + training_mode: 'product' // 标识这是药品训练模式 + }; + + const response = await axios.post(endpoint, payload); + // ... 处理响应,启动WebSocket监听 ... + } + // ... 错误处理 ... + }; + ``` + +4. **API请求**: 使用 `axios` 向后端 `POST /api/training` 发送请求。 + +### 3.2. 后端API接收与任务分发 + +1. **路由处理**: 后端 [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/training', methods=['POST'])`](server/api.py:933) 装饰器捕获该请求,并由 [`start_training()`](server/api.py:971) 函数处理。 + +2. **任务提交**: `start_training()` 函数解析请求中的JSON数据,然后调用 `training_manager.submit_task()` 将训练任务提交到一个后台进程池中执行,以避免阻塞API主线程。这使得API可以立即返回一个任务ID,而训练在后台异步进行。 + + *核心代码 ([`server/api.py`](server/api.py:971))* + ```python + @app.route('/api/training', methods=['POST']) + def start_training(): + data = request.get_json() + + training_mode = data.get('training_mode', 'product') + model_type = data.get('model_type') + epochs = data.get('epochs', 50) + product_id = data.get('product_id') + store_id = data.get('store_id') + + if not model_type or (training_mode == 'product' and not product_id): + return jsonify({'error': '缺少必要参数'}), 400 + + try: + # 使用训练进程管理器提交任务 + task_id = training_manager.submit_task( + product_id=product_id or "unknown", + model_type=model_type, + training_mode=training_mode, + store_id=store_id, + epochs=epochs + ) + + logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}") + + return jsonify({ + 'message': '模型训练已开始(使用独立进程)', + 'task_id': task_id, + }) + + except Exception as e: + logger.error(f"❌ 提交训练任务失败: {str(e)}") + return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500 + ``` + +### 3.3. 核心训练逻辑 + +1. **调用核心预测器**: 后台进程最终会调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.train_model()`](server/core/predictor.py:63) 方法。 + +2. **数据准备**: `train_model` 方法首先根据 `training_mode` (`'product'`) 和 `product_id` 从数据源加载并聚合所有店铺关于该药品的销售数据。 + +3. **分发到具体训练器**: 接着,它根据 `model_type` 调用相应的训练函数。例如,如果 `model_type` 是 `transformer`,它会调用 `train_product_model_with_transformer`。 + + *核心代码 ([`server/core/predictor.py`](server/core/predictor.py:63))* + ```python + class PharmacyPredictor: + def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...): + # ... + if training_mode == 'product': + product_data = self.data[self.data['product_id'] == product_id].copy() + # ... + + # 根据训练模式构建模型标识符 + model_identifier = product_id + + try: + if model_type == 'transformer': + model_result, metrics, actual_version = train_product_model_with_transformer( + product_id=product_id, + model_identifier=model_identifier, + product_df=product_data, + # ... 其他参数 ... + ) + # ... 其他模型的elif分支 ... + + return metrics + except Exception as e: + # ... 错误处理 ... + return None + ``` + +### 3.4. 模型训练与保存 + +1. **具体训练器**: 以 [`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:1) 为例,`train_product_model_with_transformer` 函数执行以下步骤: + * **数据预处理**: 调用 `prepare_data` 和 `prepare_sequences` 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。 + * **模型实例化**: 创建 `TimeSeriesTransformer` 模型实例。 + * **训练循环**: 执行指定的 `epochs` 次训练,计算损失并使用优化器更新模型权重。 + * **进度更新**: 在训练过程中,通过 `socketio.emit` 向前端发送 `training_progress` 事件,实时更新进度条和日志。 + * **模型保存**: 训练完成后,将模型权重 (`model.state_dict()`)、完整的模型配置 (`config`) 以及数据缩放器 (`scaler_X`, `scaler_y`) 打包成一个字典(checkpoint),并使用 `torch.save()` 保存到 `.pth` 文件中。文件名由 `get_model_file_path` 根据 `model_identifier`、`model_type` 和 `version` 统一生成。 + + *核心代码 ([`server/trainers/transformer_trainer.py`](server/trainers/transformer_trainer.py:33))* + ```python + def train_product_model_with_transformer(...): + # ... 数据准备 ... + + # 定义模型配置 + config = { + 'input_dim': input_dim, + 'output_dim': forecast_horizon, + 'hidden_size': hidden_size, + # ... 所有必要的超参数 ... + 'model_type': 'transformer' + } + + model = TimeSeriesTransformer(...) + + # ... 训练循环 ... + + # 保存模型 + checkpoint = { + 'model_state_dict': model.state_dict(), + 'config': config, + 'scaler_X': scaler_X, + 'scaler_y': scaler_y, + 'metrics': test_metrics + } + + model_path = get_model_file_path(model_identifier, 'transformer', version) + torch.save(checkpoint, model_path) + + return model, test_metrics, version + ``` + +## 4. 详细流程:按药品预测 + +训练完成后,用户可以使用已保存的模型进行预测。 + +### 4.1. 前端交互与API请求 + +1. **用户操作**: 用户在 **“按药品预测”** 页面 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1)) 选择同一个药品、对应的模型和版本,然后点击 **“开始预测”**。 + +2. **触发函数**: 点击事件调用 [`startPrediction`](UI/src/views/prediction/ProductPredictionView.vue:202) 方法。 + +3. **构建Payload**: 该方法构建一个包含预测参数的 `payload`。 + + *核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:202))* + ```javascript + const startPrediction = async () => { + try { + predicting.value = true + const payload = { + product_id: form.product_id, + model_type: form.model_type, + version: form.version, + future_days: form.future_days, + // training_mode is implicitly 'product' here + } + const response = await axios.post('/api/prediction', payload) + if (response.data.status === 'success') { + predictionResult.value = response.data + await nextTick() + renderChart() + } + // ... 错误处理 ... + } + // ... + } + ``` + +4. **API请求**: 使用 `axios` 向后端 `POST /api/prediction` 发送请求。 + +### 4.2. 后端API接收与预测执行 + +1. **路由处理**: [`server/api.py`](server/api.py:1) 中的 [`@app.route('/api/prediction', methods=['POST'])`](server/api.py:1413) 捕获请求,由 [`predict()`](server/api.py:1469) 函数处理。 + +2. **调用核心预测器**: `predict()` 函数解析参数,然后调用 `run_prediction` 辅助函数,该函数内部再调用 [`server/core/predictor.py`](server/core/predictor.py:1) 中的 [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法。 + + *核心代码 ([`server/api.py`](server/api.py:1469))* + ```python + @app.route('/api/prediction', methods=['POST']) + def predict(): + try: + data = request.json + # ... 解析参数 ... + training_mode = data.get('training_mode', 'product') + product_id = data.get('product_id') + # ... + + # 根据模式确定模型标识符 + if training_mode == 'product': + model_identifier = product_id + # ... + + # 执行预测 + prediction_result = run_prediction(model_type, product_id, model_id, ...) + + # ... 格式化响应 ... + return jsonify(response_data) + except Exception as e: + # ... 错误处理 ... + ``` + +3. **分发到模型加载器**: [`PharmacyPredictor.predict()`](server/core/predictor.py:295) 方法的主要作用是再次根据 `training_mode` 和 `product_id` 确定 `model_identifier`,然后将所有参数传递给 [`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:1) 中的 [`load_model_and_predict()`](server/predictors/model_predictor.py:26) 函数。 + + *核心代码 ([`server/core/predictor.py`](server/core/predictor.py:295))* + ```python + class PharmacyPredictor: + def predict(self, product_id, model_type, ..., training_mode='product', ...): + if training_mode == 'product': + model_identifier = product_id + # ... + + return load_model_and_predict( + model_identifier, + model_type, + # ... 其他参数 ... + ) + ``` + +### 4.3. 模型加载与执行预测 + +[`load_model_and_predict()`](server/predictors/model_predictor.py:26) 是预测流程的核心,它执行以下步骤: + +1. **定位模型文件**: 使用 `get_model_file_path` 根据 `product_id` (即 `model_identifier`), `model_type`, 和 `version` 找到之前保存的 `.pth` 模型文件。 + +2. **加载Checkpoint**: 使用 `torch.load()` 加载模型文件,得到包含 `model_state_dict`, `config`, 和 `scalers` 的字典。 + +3. **重建模型**: 根据加载的 `config` 中的超参数(如 `hidden_size`, `num_layers` 等),重新创建一个与训练时结构完全相同的模型实例。**这是我们之前修复的关键点,确保所有必要参数都被保存和加载。** + +4. **加载权重**: 将加载的 `model_state_dict` 应用到新创建的模型实例上。 + +5. **准备输入数据**: 从数据源获取最新的 `sequence_length` 天的历史数据作为预测的输入。 + +6. **数据归一化**: 使用加载的 `scaler_X` 对输入数据进行归一化。 + +7. **执行预测**: 将归一化的数据输入模型 (`model(X_input)`),得到预测结果。 + +8. **反归一化**: 使用加载的 `scaler_y` 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。 + +9. **构建结果**: 将预测值和对应的未来日期组合成一个DataFrame,并连同历史数据一起返回。 + + *核心代码 ([`server/predictors/model_predictor.py`](server/predictors/model_predictor.py:26))* + ```python + def load_model_and_predict(...): + # ... 找到模型文件路径 model_path ... + + # 加载模型和配置 + checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False) + config = checkpoint['config'] + scaler_X = checkpoint['scaler_X'] + scaler_y = checkpoint['scaler_y'] + + # 创建模型实例 (以Transformer为例) + model = TimeSeriesTransformer( + num_features=config['input_dim'], + d_model=config['hidden_size'], + # ... 使用config中的所有参数 ... + ).to(DEVICE) + + # 加载模型参数 + model.load_state_dict(checkpoint['model_state_dict']) + model.eval() + + # ... 准备输入数据 ... + + # 归一化输入数据 + X_scaled = scaler_X.transform(X) + X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE) + + # 预测 + with torch.no_grad(): + y_pred_scaled = model(X_input).cpu().numpy() + + # 反归一化 + y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten() + + # ... 构建返回结果 ... + return { + 'predictions': predictions_df, + 'history_data': recent_history, + # ... + } + ``` + +### 4.4. 响应格式化与前端图表渲染 + +1. **API层格式化**: 在 [`server/api.py`](server/api.py:1) 的 [`predict()`](server/api.py:1469) 函数中,从 `load_model_and_predict` 返回的结果被精心格式化成前端期望的JSON结构,该结构在顶层同时包含 `history_data` 和 `prediction_data` 两个数组。 + +2. **前端接收数据**: 前端 [`ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:1) 在 `startPrediction` 方法中接收到这个JSON响应,并将其存入 `predictionResult` ref。 + +3. **图表渲染**: [`renderChart()`](UI/src/views/prediction/ProductPredictionView.vue:232) 方法被调用。它从 `predictionResult.value` 中提取 `history_data` 和 `prediction_data`,然后使用Chart.js库将这两部分数据绘制在同一个 `` 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。 + + *核心代码 ([`UI/src/views/prediction/ProductPredictionView.vue`](UI/src/views/prediction/ProductPredictionView.vue:232))* + ```javascript + const renderChart = () => { + if (!chartCanvas.value || !predictionResult.value) return + // ... + + // 后端直接提供 history_data 和 prediction_data + const historyData = predictionResult.value.history_data || [] + const predictionData = predictionResult.value.prediction_data || [] + + const historyLabels = historyData.map(p => p.date) + const historySales = historyData.map(p => p.sales) + + const predictionLabels = predictionData.map(p => p.date) + const predictionSales = predictionData.map(p => p.predicted_sales) + + // ... 组合标签和数据,对齐数据点 ... + + chart = new Chart(chartCanvas.value, { + type: 'line', + data: { + labels: allLabels, + datasets: [ + { + label: '历史销量', + data: alignedHistorySales, + // ... 样式 ... + }, + { + label: '预测销量', + data: alignedPredictionSales, + // ... 样式 ... + } + ] + }, + // ... Chart.js 配置 ... + }) + } + ``` + +至此,一个完整的“训练->预测->展示”的调用链路就完成了。 + +## 5. 模型保存规则与路径 + +为了确保模型的唯一性、可追溯性和可复现性,系统采用了一套严格的文件保存和命名规则。所有相关的逻辑都集中在 [`server/core/config.py`](server/core/config.py:1) 中。 + +### 5.1. 统一保存目录 + +所有训练产物,包括模型权重、配置和数据缩放器(Scalers),都保存在项目根目录下的 `saved_models/` 文件夹中。 + +- **路径**: `PROJECT_ROOT/saved_models/` +- **定义**: 该路径由 [`server/core/config.py`](server/core/config.py:1) 中的 `DEFAULT_MODEL_DIR` 变量指定。 + +### 5.2. 文件命名规范 + +模型文件的命名遵循一个标准化的格式,以便在预测时能够被精确地定位和加载。该命名逻辑由 [`get_model_file_path()`](server/core/config.py:136) 函数统一管理。 + +**命名格式**: `{model_type}_{model_identifier}_epoch_{version}.pth` + +**各部分说明**: + +- `{model_type}`: 模型的算法类型。例如:`transformer`, `mlstm`, `tcn`, `kan`。 +- `{model_identifier}`: 模型的唯一业务标识符,它根据训练模式(`training_mode`)动态生成: + - **按药品训练 (`product`)**: 标识符就是 `product_id`。 + - *示例*: `transformer_17002608_epoch_best.pth` + - **按店铺训练 (`store`)**: 标识符是 `store_{store_id}`。 + - *示例*: `tcn_store_01010023_epoch_best.pth` + - **全局训练 (`global`)**: 标识符是固定的字符串 `'global'`。 + - *示例*: `mlstm_global_epoch_best.pth` +- `{version}`: 模型的版本。在训练过程中,通常会保存两个版本: + - `best`: 在验证集上表现最佳的模型。 + - `{epoch_number}`: 训练完成时的最终模型,例如 `50`。 + 前端的“版本”下拉框中显示的就是这些版本字符串。 + +### 5.3. Checkpoint文件内容 + +每个 `.pth` 文件都是一个PyTorch Checkpoint,它是一个Python字典,包含了重建和使用模型所需的所有信息。这是确保预测与训练环境一致的关键。 + +**Checkpoint结构**: + +```python +checkpoint = { + # 1. 模型权重 + 'model_state_dict': model.state_dict(), + + # 2. 完整的模型配置 + 'config': { + 'input_dim': ..., + 'hidden_size': ..., + 'num_layers': ..., + 'model_type': 'transformer', + # ... 其他所有重建模型所需的超参数 ... + }, + + # 3. 数据归一化缩放器 + 'scaler_X': scaler_X, # 用于输入特征 + 'scaler_y': scaler_y, # 用于目标值(销量) + + # 4. (可选) 模型性能指标 + 'metrics': {'mse': 0.01, 'mae': 0.05, ...} +} +``` + +**核心优势**: + +- **可复现性**: 通过保存完整的 `config`,我们可以在预测时精确地重建出与训练时结构完全相同的模型实例,避免了因模型结构不匹配导致的加载失败(这是之前修复的一个核心BUG)。 +- **数据一致性**: 保存 `scaler_X` 和 `scaler_y` 确保了在预测时使用与训练时完全相同的归一化/反归一化逻辑,保证了预测结果的正确性。 \ No newline at end of file diff --git a/项目快速上手指南.md b/项目快速上手指南.md new file mode 100644 index 0000000..1912719 --- /dev/null +++ b/项目快速上手指南.md @@ -0,0 +1,116 @@ +# 项目快速上手指南 (面向新开发者) + +欢迎加入项目!本指南旨在帮助你快速理解项目的核心功能、技术架构和开发流程,特别是为你(一位Java背景的开发者)提供清晰的切入点。 + +## 1. 项目是做什么的?(实现了什么功能) + +这是一个基于历史销售数据的 **智能销售预测系统**。 + +核心功能有三个,全部通过Web界面操作: +1. **模型训练**: 用户可以选择某个**药品**、某个**店铺**或**全局**数据,然后选择一种机器学习算法(如Transformer、mLSTM等)进行训练,最终生成一个预测模型。 +2. **销售预测**: 使用已经训练好的模型,对未来的销量进行预测。 +3. **结果可视化**: 将历史销量和预测销量在同一个图表中展示出来,方便用户直观地看到趋势。 + +简单来说,它就是一个 **"数据 -> 训练 -> 模型 -> 预测 -> 可视化"** 的完整闭环应用。 + +## 2. 用了什么技术?(技术栈) + +你可以将这个项目的技术栈与Java世界进行类比: + +| 层面 | 本项目技术 | Java世界类比 | 说明 | +| :--- | :--- | :--- | :--- | +| **后端框架** | **Flask** | Spring Boot | 一个轻量级的Web框架,用于提供API接口。 | +| **前端框架** | **Vue.js** | React / Angular | 用于构建用户交互界面的现代化JavaScript框架。 | +| **核心算法库** | **PyTorch** | (无直接对应) | 类似于Java的Deeplearning4j,是实现深度学习算法的核心。 | +| **数据处理** | **Pandas** | (无直接对应) | Python中用于数据分析和处理的“瑞士军刀”,可以看作是内存中的强大数据表格。 | +| **构建/打包** | **Vite** (前端) | Maven / Gradle | 前端项目的构建和依赖管理工具。 | +| **数据库** | **SQLite** | H2 / MySQL | 一个轻量级的本地文件数据库,用于记录预测历史等。 | +| **实时通信** | **Socket.IO** | WebSocket / STOMP | 用于后端在训练时向前端实时推送进度。 | + +## 3. 系统架构是怎样的?(架构层级和设计) + +本项目是经典的前后端分离架构,可以分为四个主要层次: + +``` ++------------------------------------------------------+ +| 用户 (Browser) | ++------------------------------------------------------+ + | ++------------------------------------------------------+ +| 1. 前端层 (Frontend - Vue.js) | +| - Views (页面组件, e.g., ProductPredictionView.vue) | +| - API Calls (使用axios与后端通信) | +| - Charting (使用Chart.js进行图表渲染) | ++------------------------------------------------------+ + | (HTTP/S, WebSocket) ++------------------------------------------------------+ +| 2. 后端API层 (Backend API - Flask) | +| - api.py (类似Controller, 定义RESTful接口) | +| - 接收请求, 验证参数, 调用业务逻辑层 | ++------------------------------------------------------+ + | ++------------------------------------------------------+ +| 3. 业务逻辑层 (Business Logic - Python) | +| - core/predictor.py (类似Service层) | +| - 封装核心业务, 如“根据参数选择合适的训练器” | ++------------------------------------------------------+ + | ++------------------------------------------------------+ +| 4. 数据与模型层 (Data & Model - PyTorch/Pandas) | +| - trainers/*.py (具体的算法实现和训练逻辑) | +| - predictors/model_predictor.py (模型加载与预测逻辑) | +| - saved_models/ (存放训练好的.pth模型文件) | +| - data/ (存放原始数据.parquet文件) | ++------------------------------------------------------+ +``` + +## 4. 关键执行流程 + +以最常见的“按药品预测”为例: + +1. **前端**: 用户在页面上选择药品和模型,点击“预测”按钮。Vue组件通过`axios`向后端发送一个POST请求到 `/api/prediction`。 +2. **API层**: `api.py` 接收到请求,像一个Controller一样,解析出药品ID、模型类型等参数。 +3. **业务逻辑层**: `api.py` 调用 `core/predictor.py` 中的 `predict` 方法,将参数传递下去。这一层是业务的“调度中心”。 +4. **模型层**: `core/predictor.py` 最终调用 `predictors/model_predictor.py` 中的 `load_model_and_predict` 函数。 +5. **模型加载与执行**: + * 根据参数在 `saved_models/` 目录下找到对应的模型文件(例如 `transformer_17002608_epoch_best.pth`)。 + * 加载文件,从中恢复出 **模型结构**、**模型权重** 和 **数据缩放器**。 + * 准备最新的历史数据作为输入,执行预测。 + * 将预测结果返回。 +6. **返回与渲染**: 结果逐层返回到`api.py`,在这里被格式化为JSON,然后发送给前端。前端接收到JSON后,使用`Chart.js`将历史和预测数据画在图表上。 + +## 5. 如何添加一个新的算法?(开发者指南) + +这是你最可能接触到的新功能开发。假设你要添加一个名为 `NewNet` 的新算法,你需要按以下步骤操作: + +**目标**: 让 `NewNet` 出现在前端的“模型类型”下拉框中,并能成功训练和预测。 + +1. **创建训练器文件**: + * 在 `server/trainers/` 目录下,复制一份现有的训练器文件(例如 `tcn_trainer.py`)并重命名为 `newnet_trainer.py`。 + * 在 `newnet_trainer.py` 中: + * 定义你的 `NewNet` 模型类(继承自 `torch.nn.Module`)。 + * 修改 `train_..._with_tcn` 函数,将其重命名为 `train_..._with_newnet`。 + * 在这个新函数里,确保实例化的是你的 `NewNet` 模型。 + * **最关键的一步**: 在保存checkpoint时,确保 `config` 字典里包含了重建 `NewNet` 所需的所有超参数(比如层数、节点数等)。 + +2. **注册新模型**: + * 打开 `server/core/config.py` 文件。 + * 找到 `SUPPORTED_MODELS` 列表。 + * 在列表中添加你的新模型标识符 `'newnet'`。 + +3. **接入业务逻辑层 (训练)**: + * 打开 `server/core/predictor.py` 文件。 + * 在 `train_model` 方法中,找到 `if/elif` 模型选择逻辑。 + * 添加一个新的 `elif model_type == 'newnet':` 分支,让它调用你在第一步中创建的 `train_..._with_newnet` 函数。 + +4. **接入模型层 (预测)**: + * 打开 `server/predictors/model_predictor.py` 文件。 + * 在 `load_model_and_predict` 函数中,找到 `if/elif` 模型实例化逻辑。 + * 添加一个新的 `elif model_type == 'newnet':` 分支,确保它能根据 `config` 正确地创建 `NewNet` 模型实例。 + +5. **更新前端界面**: + * 打开 `UI/src/views/training/` 和 `UI/src/views/prediction/` 目录下的相关Vue文件(如 `ProductTrainingView.vue`)。 + * 找到定义模型选项的地方(通常是一个数组或对象)。 + * 添加 `{ label: '新网络模型 (NewNet)', value: 'newnet' }` 这样的新选项。 + +完成以上步骤后,重启服务,你就可以在界面上选择并使用你的新算法了。 \ No newline at end of file