ShopTRAINING/系统调用逻辑与核心代码分析.md

22 KiB
Raw Blame History

系统调用逻辑与核心代码分析

本文档旨在详细阐述本销售预测系统的端到端调用链路,从系统启动、前端交互、后端处理,到最终的模型训练、预测和图表展示。

1. 系统启动

系统由两部分组成Vue.js前端和Flask后端。

1.1. 启动后端API服务

在项目根目录下,通过以下命令启动后端服务:

python server/api.py

该命令会启动一个Flask应用监听在 http://localhost:5000并提供所有API和WebSocket服务。

1.2. 启动前端开发服务器

进入 UI 目录,执行以下命令:

cd UI
npm install
npm run dev

这将启动Vite开发服务器通常在 http://localhost:5173,并自动打开浏览器访问前端页面。

2. 核心调用链路概览

以最核心的 “按药品训练 -> 按药品预测” 流程为例,其高层调用链路如下:

训练流程: 前端UI -> POST /api/training -> api.py: start_training() -> TrainingManager -> 后台进程 -> predictor.py: train_model() -> [model]_trainer.py: train_product_model_with_*() -> 保存模型.pth

预测流程: 前端UI -> POST /api/prediction -> api.py: predict() -> predictor.py: predict() -> model_predictor.py: load_model_and_predict() -> 加载模型.pth -> 返回预测JSON -> 前端图表渲染

3. 详细流程:按药品训练

此流程的目标是为特定药品训练一个专用的预测模型。

3.1. 前端交互与API请求

  1. 用户操作: 用户在 “按药品训练” 页面 (UI/src/views/training/ProductTrainingView.vue) 选择一个药品、一个模型类型如Transformer、设置训练轮次Epochs然后点击 “启动药品训练” 按钮。

  2. 触发函数: 点击事件调用 startTraining 方法。

  3. 构建Payload: startTraining 方法构建一个包含训练参数的 payload 对象。关键字段是 training_mode: 'product',用于告知后端这是针对特定产品的训练。

    核心代码 (UI/src/views/training/ProductTrainingView.vue)

    const startTraining = async () => {
      // ... 表单验证 ...
      trainingLoading.value = true;
      try {
        const endpoint = "/api/training";
    
        const payload = {
          product_id: form.product_id,
          store_id: form.data_scope === 'global' ? null : form.store_id,
          model_type: form.model_type,
          epochs: form.epochs,
          training_mode: 'product'  // 标识这是药品训练模式
        };
    
        const response = await axios.post(endpoint, payload);
        // ... 处理响应启动WebSocket监听 ...
      } 
      // ... 错误处理 ...
    };
    
  4. API请求: 使用 axios 向后端 POST /api/training 发送请求。

3.2. 后端API接收与任务分发

  1. 路由处理: 后端 server/api.py 中的 @app.route('/api/training', methods=['POST']) 装饰器捕获该请求,并由 start_training() 函数处理。

  2. 任务提交: start_training() 函数解析请求中的JSON数据然后调用 training_manager.submit_task() 将训练任务提交到一个后台进程池中执行以避免阻塞API主线程。这使得API可以立即返回一个任务ID而训练在后台异步进行。

    核心代码 (server/api.py)

    @app.route('/api/training', methods=['POST'])
    def start_training():
        data = request.get_json()
    
        training_mode = data.get('training_mode', 'product')
        model_type = data.get('model_type')
        epochs = data.get('epochs', 50)
        product_id = data.get('product_id')
        store_id = data.get('store_id')
    
        if not model_type or (training_mode == 'product' and not product_id):
            return jsonify({'error': '缺少必要参数'}), 400
    
        try:
            # 使用训练进程管理器提交任务
            task_id = training_manager.submit_task(
                product_id=product_id or "unknown",
                model_type=model_type,
                training_mode=training_mode,
                store_id=store_id,
                epochs=epochs
            )
    
            logger.info(f"🚀 训练任务已提交到进程管理器: {task_id[:8]}")
    
            return jsonify({
                'message': '模型训练已开始(使用独立进程)',
                'task_id': task_id,
            })
    
        except Exception as e:
            logger.error(f"❌ 提交训练任务失败: {str(e)}")
            return jsonify({'error': f'启动训练任务失败: {str(e)}'}), 500
    

3.3. 核心训练逻辑

  1. 调用核心预测器: 后台进程最终会调用 server/core/predictor.py 中的 PharmacyPredictor.train_model() 方法。

  2. 数据准备: train_model 方法首先根据 training_mode ('product') 和 product_id 从数据源加载并聚合所有店铺关于该药品的销售数据。

  3. 分发到具体训练器: 接着,它根据 model_type 调用相应的训练函数。例如,如果 model_typetransformer,它会调用 train_product_model_with_transformer

    核心代码 (server/core/predictor.py)

    class PharmacyPredictor:
        def train_model(self, product_id, model_type='transformer', ..., training_mode='product', ...):
            # ...
            if training_mode == 'product':
                product_data = self.data[self.data['product_id'] == product_id].copy()
                # ...
    
            # 根据训练模式构建模型标识符
            model_identifier = product_id
    
            try:
                if model_type == 'transformer':
                    model_result, metrics, actual_version = train_product_model_with_transformer(
                        product_id=product_id,
                        model_identifier=model_identifier,
                        product_df=product_data,
                        # ... 其他参数 ...
                    )
                # ... 其他模型的elif分支 ...
    
                return metrics
            except Exception as e:
                # ... 错误处理 ...
                return None
    

3.4. 模型训练与保存

  1. 具体训练器: 以 server/trainers/transformer_trainer.py 为例,train_product_model_with_transformer 函数执行以下步骤:

    • 数据预处理: 调用 prepare_dataprepare_sequences 将原始销售数据转换为模型可以理解的、带有时间序列特征的监督学习格式(输入序列和目标序列)。
    • 模型实例化: 创建 TimeSeriesTransformer 模型实例。
    • 训练循环: 执行指定的 epochs 次训练,计算损失并使用优化器更新模型权重。
    • 进度更新: 在训练过程中,通过 socketio.emit 向前端发送 training_progress 事件,实时更新进度条和日志。
    • 模型保存: 训练完成后,将模型权重 (model.state_dict())、完整的模型配置 (config) 以及数据缩放器 (scaler_X, scaler_y) 打包成一个字典checkpoint并使用 torch.save() 保存到 .pth 文件中。文件名由 get_model_file_path 根据 model_identifiermodel_typeversion 统一生成。

    核心代码 (server/trainers/transformer_trainer.py)

    def train_product_model_with_transformer(...):
        # ... 数据准备 ...
    
        # 定义模型配置
        config = {
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'hidden_size': hidden_size,
            # ... 所有必要的超参数 ...
            'model_type': 'transformer'
        }
    
        model = TimeSeriesTransformer(...)
    
        # ... 训练循环 ...
    
        # 保存模型
        checkpoint = {
            'model_state_dict': model.state_dict(),
            'config': config,
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
            'metrics': test_metrics
        }
    
        model_path = get_model_file_path(model_identifier, 'transformer', version)
        torch.save(checkpoint, model_path)
    
        return model, test_metrics, version
    

4. 详细流程:按药品预测

训练完成后,用户可以使用已保存的模型进行预测。

4.1. 前端交互与API请求

  1. 用户操作: 用户在 “按药品预测” 页面 (UI/src/views/prediction/ProductPredictionView.vue) 选择同一个药品、对应的模型和版本,然后点击 “开始预测”

  2. 触发函数: 点击事件调用 startPrediction 方法。

  3. 构建Payload: 该方法构建一个包含预测参数的 payload

    核心代码 (UI/src/views/prediction/ProductPredictionView.vue)

    const startPrediction = async () => {
      try {
        predicting.value = true
        const payload = {
          product_id: form.product_id,
          model_type: form.model_type,
          version: form.version,
          future_days: form.future_days,
          // training_mode is implicitly 'product' here
        }
        const response = await axios.post('/api/prediction', payload)
        if (response.data.status === 'success') {
          predictionResult.value = response.data
          await nextTick()
          renderChart()
        } 
        // ... 错误处理 ...
      } 
      // ...
    }
    
  4. API请求: 使用 axios 向后端 POST /api/prediction 发送请求。

4.2. 后端API接收与预测执行

  1. 路由处理: server/api.py 中的 @app.route('/api/prediction', methods=['POST']) 捕获请求,由 predict() 函数处理。

  2. 调用核心预测器: predict() 函数解析参数,然后调用 run_prediction 辅助函数,该函数内部再调用 server/core/predictor.py 中的 PharmacyPredictor.predict() 方法。

    核心代码 (server/api.py)

    @app.route('/api/prediction', methods=['POST'])
    def predict():
        try:
            data = request.json
            # ... 解析参数 ...
            training_mode = data.get('training_mode', 'product')
            product_id = data.get('product_id')
            # ...
    
            # 根据模式确定模型标识符
            if training_mode == 'product':
                model_identifier = product_id
            # ...
    
            # 执行预测
            prediction_result = run_prediction(model_type, product_id, model_id, ...)
    
            # ... 格式化响应 ...
            return jsonify(response_data)
        except Exception as e:
            # ... 错误处理 ...
    
  3. 分发到模型加载器: PharmacyPredictor.predict() 方法的主要作用是再次根据 training_modeproduct_id 确定 model_identifier,然后将所有参数传递给 server/predictors/model_predictor.py 中的 load_model_and_predict() 函数。

    核心代码 (server/core/predictor.py)

    class PharmacyPredictor:
        def predict(self, product_id, model_type, ..., training_mode='product', ...):
            if training_mode == 'product':
                model_identifier = product_id
            # ...
    
            return load_model_and_predict(
                model_identifier,
                model_type,
                # ... 其他参数 ...
            )
    

4.3. 模型加载与执行预测

load_model_and_predict() 是预测流程的核心,它执行以下步骤:

  1. 定位模型文件: 使用 get_model_file_path 根据 product_id (即 model_identifier), model_type, 和 version 找到之前保存的 .pth 模型文件。

  2. 加载Checkpoint: 使用 torch.load() 加载模型文件,得到包含 model_state_dict, config, 和 scalers 的字典。

  3. 重建模型: 根据加载的 config 中的超参数(如 hidden_size, num_layers 等),重新创建一个与训练时结构完全相同的模型实例。这是我们之前修复的关键点,确保所有必要参数都被保存和加载。

  4. 加载权重: 将加载的 model_state_dict 应用到新创建的模型实例上。

  5. 准备输入数据: 从数据源获取最新的 sequence_length 天的历史数据作为预测的输入。

  6. 数据归一化: 使用加载的 scaler_X 对输入数据进行归一化。

  7. 执行预测: 将归一化的数据输入模型 (model(X_input)),得到预测结果。

  8. 反归一化: 使用加载的 scaler_y 将模型的输出(预测值)反归一化,转换回原始的销售量尺度。

  9. 构建结果: 将预测值和对应的未来日期组合成一个DataFrame并连同历史数据一起返回。

    核心代码 (server/predictors/model_predictor.py)

    def load_model_and_predict(...):
        # ... 找到模型文件路径 model_path ...
    
        # 加载模型和配置
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint['config']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']
    
        # 创建模型实例 (以Transformer为例)
        model = TimeSeriesTransformer(
            num_features=config['input_dim'],
            d_model=config['hidden_size'],
            # ... 使用config中的所有参数 ...
        ).to(DEVICE)
    
        # 加载模型参数
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
    
        # ... 准备输入数据 ...
    
        # 归一化输入数据
        X_scaled = scaler_X.transform(X)
        X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), ...).to(DEVICE)
    
        # 预测
        with torch.no_grad():
            y_pred_scaled = model(X_input).cpu().numpy()
    
        # 反归一化
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
    
        # ... 构建返回结果 ...
        return {
            'predictions': predictions_df,
            'history_data': recent_history,
            # ...
        }
    

4.4. 响应格式化与前端图表渲染

  1. API层格式化: 在 server/api.pypredict() 函数中,从 load_model_and_predict 返回的结果被精心格式化成前端期望的JSON结构该结构在顶层同时包含 history_dataprediction_data 两个数组。

  2. 前端接收数据: 前端 ProductPredictionView.vuestartPrediction 方法中接收到这个JSON响应并将其存入 predictionResult ref。

  3. 图表渲染: renderChart() 方法被调用。它从 predictionResult.value 中提取 history_dataprediction_data然后使用Chart.js库将这两部分数据绘制在同一个 <canvas> 上,历史数据为实线,预测数据为虚线,从而形成一个连续的趋势图。

    核心代码 (UI/src/views/prediction/ProductPredictionView.vue)

    const renderChart = () => {
      if (!chartCanvas.value || !predictionResult.value) return
      // ...
    
      // 后端直接提供 history_data 和 prediction_data
      const historyData = predictionResult.value.history_data || []
      const predictionData = predictionResult.value.prediction_data || []
    
      const historyLabels = historyData.map(p => p.date)
      const historySales = historyData.map(p => p.sales)
    
      const predictionLabels = predictionData.map(p => p.date)
      const predictionSales = predictionData.map(p => p.predicted_sales)
    
      // ... 组合标签和数据,对齐数据点 ...
    
      chart = new Chart(chartCanvas.value, {
        type: 'line',
        data: {
          labels: allLabels,
          datasets: [
            {
              label: '历史销量',
              data: alignedHistorySales,
              // ... 样式 ...
            },
            {
              label: '预测销量',
              data: alignedPredictionSales,
              // ... 样式 ...
            }
          ]
        },
        // ... Chart.js 配置 ...
      })
    }
    

至此,一个完整的“训练->预测->展示”的调用链路就完成了。

5. 模型保存与版本管理核心逻辑 (重构后)

为了根治版本混乱和模型加载失败的问题,系统进行了一项重要的重构。现在,所有与模型保存、命名和版本管理相关的逻辑都已统一集中server/utils/model_manager.pyModelManager 类中。

5.1. 统一管理者:ModelManager

  • 单一职责: ModelManager 是系统中唯一负责处理模型文件IO的组件。所有训练器 (trainer) 在需要保存模型时,都必须通过它来进行。
  • 核心功能:
    1. 自动版本控制: 自动生成和递增符合规范的版本号。
    2. 统一命名: 根据模型的元数据算法类型、训练模式、ID等生成标准化的文件名。
    3. 安全保存: 将模型数据和元数据一起打包保存到 .pth 文件中。
    4. 可靠检索: 提供统一的接口来列出和查找模型。

5.2. 统一版本规范

所有模型版本现在都遵循一个严格的、可预测的格式:

  • 数字版本: v{数字},例如 v1, v2, v3...
    • 生成: 当一次训练正常完成时,ModelManager 会自动计算出当前模型的下一个可用版本号(例如,如果已存在 v1v2,则新版本为 v3),并以此命名最终的模型文件。
    • 用途: 代表一次完整的、稳定的训练产出。
  • 特殊版本: best
    • 生成: 在训练过程中,如果某个 epoch 产生的模型在验证集上的性能超过了之前所有 epoch,训练器会调用 ModelManager 将这个模型保存为 best 版本,覆盖掉旧的 best 模型。
    • 用途: 始终指向该模型迄今为止性能最佳的一个版本,便于快速进行高质量的预测。

5.3. 统一命名约定 (v2版)

随着系统增加了“按店铺”和“全局”训练模式,ModelManagergenerate_model_filename 方法也已升级,以支持更丰富的、无歧义的命名格式:

  • 药品模型: {model_type}_product_{product_id}_{version}.pth
    • 示例: transformer_product_17002608_best.pth
  • 店铺模型: {model_type}_store_{store_id}_{version}.pth
    • 示例: mlstm_store_01010023_v2.pth
  • 全局模型: {model_type}_global_{aggregation_method}_{version}.pth
    • 示例: tcn_global_sum_v1.pth

这个新的命名系统确保了不同训练模式产出的模型可以清晰地被识别和管理。

5.4. Checkpoint文件内容 (结构不变)

每个 .pth 文件依然是一个包含模型权重、完整配置和数据缩放器的PyTorch Checkpoint。重构加强了所有训练器都必须将完整的配置信息存入 config 字典这一规则,确保了模型的完全可复现性。

5.5. 核心优势 (重构后)

  • 逻辑集中: 所有版本管理的复杂性都被封装在 ModelManager 内部,训练器只需调用 save_model 即可,无需关心版本号如何生成。
  • 数据一致性: 由于版本的生成、保存和检索都由同一个组件以同一种逻辑处理,从根本上杜绝了因命名或版本格式不匹配导致“模型未找到”的问题。
  • 易于维护: 未来如果需要修改版本策略或命名规则,只需修改 ModelManager 一个文件即可,无需改动所有训练器。

6. 核心流程的演进:支持店铺与全局模式

在最初的“按药品”流程基础上系统已重构以支持“按店铺”和“全局”的完整AI闭环。这引入了一些关键的逻辑变化

6.1. 训练流程的变化

  • 统一入口: 所有训练请求(药品、店铺、全局)都通过 POST /api/training 接口,由 training_mode 参数区分。
  • 数据聚合: 在 predictor.pytrain_model 方法中,会根据 training_mode 调用 aggregate_multi_store_data 函数,为店铺或全局模式准备正确的聚合时间序列数据。
  • 模型标识符: train_model 方法现在会生成一个唯一的 model_identifier(例如 product_17002608, store_01010023, global_sum),并将其传递给所有下游训练器。这是确保模型被正确命名的关键。

6.2. 预测流程的重大修复

预测流程经过了重大修复,以解决之前因逻辑不统一导致的 404 错误。

  • 废弃旧函数: core/config.py 中的 get_model_file_pathget_model_versions 等旧的、有缺陷的辅助函数已被完全废弃
  • 统一查找逻辑: 现在,api.pypredict 函数必须使用 model_manager.list_models() 方法来查找模型。
  • 可靠的路径传递: predict 函数找到正确的模型文件路径后,会将其作为一个参数,一路传递给 run_prediction 和最终的 load_model_and_predict 函数。
  • 根除缺陷: load_model_and_predict 函数内部所有手动的、过时的文件查找逻辑已被完全移除。它现在只负责接收一个明确的路径并加载模型。

这个修复确保了整个预测链路都依赖于 ModelManager 这一个“单一事实来源”,从根本上解决了因路径不匹配导致的预测失败问题。