From 9d7dcae1c87fff40a39431da977d79a48642d1d4 Mon Sep 17 00:00:00 2001 From: xz2000 Date: Wed, 23 Jul 2025 16:55:27 +0800 Subject: [PATCH] =?UTF-8?q?=E4=B8=80=E3=80=81=E4=BD=BF=E7=94=A8Swagger=20U?= =?UTF-8?q?I=20=E5=B1=95=E7=A4=BA=E8=8D=AF=E5=BA=97=E9=94=80=E5=94=AE?= =?UTF-8?q?=E9=A2=84=E6=B5=8B=E7=B3=BB=E7=BB=9FAPI=20=E4=BA=8C=E3=80=81?= =?UTF-8?q?=E5=AE=8C=E6=88=90=E6=96=B0=E5=A2=9E=E6=A8=A1=E5=9E=8Bxgboost?= =?UTF-8?q?=EF=BC=8Ccnn=5Fbilstm=5Fattention=E7=9A=84=E8=AE=AD=E7=BB=83?= =?UTF-8?q?=EF=BC=8C=E9=A2=84=E6=B5=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- server/swagger.json | 465 ++++++++++++++++++ .../trainers/cnn_bilstm_attention_trainer.py | 96 +++- server/trainers/xgboost_trainer.py | 39 +- 3 files changed, 579 insertions(+), 21 deletions(-) create mode 100644 server/swagger.json diff --git a/server/swagger.json b/server/swagger.json new file mode 100644 index 0000000..cbbe573 --- /dev/null +++ b/server/swagger.json @@ -0,0 +1,465 @@ +{ + "openapi": "3.0.0", + "info": { + "title": "药店销售预测系统API", + "description": "用于药店销售预测的RESTful API", + "version": "1.0.0", + "contact": { + "name": "API开发团队", + "email": "support@example.com" + } + }, + "tags": [ + { + "name": "数据管理", + "description": "数据上传和查询相关接口" + }, + { + "name": "模型训练", + "description": "模型训练相关接口" + }, + { + "name": "模型预测", + "description": "预测销售数据相关接口" + }, + { + "name": "模型管理", + "description": "模型查询、导出和删除接口" + } + ], + "paths": { + "/api/products": { + "get": { + "tags": ["数据管理"], + "summary": "获取所有产品列表", + "description": "返回系统中所有产品的ID和名称", + "responses": { + "200": { + "description": "成功获取产品列表", + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "status": {"type": "string"}, + "data": { + "type": "array", + "items": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + "product_name": {"type": "string"} + } + } + } + } + } + } + } + }, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/products/{product_id}": { + "get": { + "tags": ["数据管理"], + "summary": "获取单个产品详情", + "description": "返回指定产品ID的详细信息", + "parameters": [ + { + "name": "product_id", + "in": "path", + "required": true, + "schema": {"type": "string"}, + "description": "产品ID,例如P001" + } + ], + "responses": { + "200": {"description": "成功获取产品详情"}, + "404": {"description": "产品不存在"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/products/{product_id}/sales": { + "get": { + "tags": ["数据管理"], + "summary": "获取产品销售数据", + "description": "返回指定产品在特定日期范围内的销售数据", + "parameters": [ + { + "name": "product_id", + "in": "path", + "required": true, + "schema": {"type": "string"}, + "description": "产品ID,例如P001" + }, + { + "name": "start_date", + "in": "query", + "schema": {"type": "string"}, + "description": "开始日期,格式为YYYY-MM-DD" + }, + { + "name": "end_date", + "in": "query", + "schema": {"type": "string"}, + "description": "结束日期,格式为YYYY-MM-DD" + } + ], + "responses": { + "200": {"description": "成功获取销售数据"}, + "404": {"description": "产品不存在"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/data/upload": { + "post": { + "tags": ["数据管理"], + "summary": "上传销售数据", + "description": "上传新的销售数据文件(Excel格式)", + "requestBody": { + "content": { + "multipart/form-data": { + "schema": { + "type": "object", + "properties": { + "file": { + "type": "string", + "format": "binary" + } + } + } + } + } + }, + "responses": { + "200": {"description": "数据上传成功"}, + "400": {"description": "请求错误"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/training": { + "get": { + "tags": ["模型训练"], + "summary": "获取所有训练任务列表", + "description": "返回所有正在进行、已完成或失败的训练任务", + "responses": { + "200": {"description": "成功获取任务列表"} + } + }, + "post": { + "tags": ["模型训练"], + "summary": "启动模型训练任务", + "description": "为指定产品启动一个新的模型训练任务", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "product_id": {"type": "string", "example": "P001"}, + "model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn", "xgboost"]}, + "store_id": {"type": "string", "example": "S001"}, + "epochs": {"type": "integer", "default": 50} + }, + "required": ["product_id", "model_type"] + } + } + } + }, + "responses": { + "200": {"description": "训练任务已启动"}, + "400": {"description": "请求错误"} + } + } + }, + "/api/training/{task_id}": { + "get": { + "tags": ["模型训练"], + "summary": "查询训练任务状态", + "description": "获取特定训练任务的当前状态和详情", + "parameters": [ + { + "name": "task_id", + "in": "path", + "required": true, + "schema": {"type": "string"}, + "description": "训练任务ID" + } + ], + "responses": { + "200": {"description": "成功获取任务状态"}, + "404": {"description": "任务不存在"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/prediction": { + "post": { + "tags": ["模型预测"], + "summary": "使用模型进行预测", + "description": "使用指定模型预测未来销售数据", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + "model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn"]}, + "store_id": {"type": "string"}, + "version": {"type": "string"}, + "future_days": {"type": "integer"}, + "include_visualization": {"type": "boolean"}, + "start_date": {"type": "string"} + }, + "required": ["product_id", "model_type"] + } + } + } + }, + "responses": { + "200": {"description": "预测成功"}, + "400": {"description": "请求错误"}, + "404": {"description": "产品或模型不存在"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/prediction/compare": { + "post": { + "tags": ["模型预测"], + "summary": "比较不同模型预测结果", + "description": "比较不同模型对同一产品的预测结果", + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "type": "object", + "properties": { + "product_id": {"type": "string"}, + "model_types": {"type": "array", "items": {"type": "string"}}, + "versions": {"type": "array", "items": {"type": "string"}}, + "include_visualization": {"type": "boolean"} + }, + "required": ["product_id", "model_types"] + } + } + } + }, + "responses": { + "200": {"description": "比较成功"}, + "400": {"description": "请求错误"}, + "404": {"description": "产品或模型不存在"}, + "500": {"description": "服务器内部错误"} + } + } + }, + "/api/prediction/history": { + "get": { + "tags": ["模型预测"], + "summary": "获取历史预测记录", + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/prediction/history/{prediction_id}": { + "get": { + "tags": ["模型预测"], + "summary": "获取特定预测记录的详情", + "parameters": [ + { + "name": "prediction_id", + "in": "path", + "required": true, + "schema": {"type": "string"} + } + ], + "responses": { + "200": {"description": "获取成功"}, + "404": {"description": "记录不存在"} + } + }, + "delete": { + "tags": ["模型预测"], + "summary": "删除预测记录", + "parameters": [ + { + "name": "prediction_id", + "in": "path", + "required": true, + "schema": {"type": "string"} + } + ], + "responses": { + "200": {"description": "删除成功"}, + "404": {"description": "记录不存在"} + } + } + }, + "/api/models": { + "get": { + "tags": ["模型管理"], + "summary": "获取模型列表", + "parameters": [ + {"name": "product_id", "in": "query", "schema": {"type": "string"}}, + {"name": "model_type", "in": "query", "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/models/{model_id}": { + "get": { + "tags": ["模型管理"], + "summary": "获取模型详情", + "parameters": [ + {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"}, + "404": {"description": "模型不存在"} + } + }, + "delete": { + "tags": ["模型管理"], + "summary": "删除模型", + "parameters": [ + {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "删除成功"}, + "404": {"description": "模型不存在"} + } + } + }, + "/api/models/{model_id}/export": { + "get": { + "tags": ["模型管理"], + "summary": "导出模型", + "parameters": [ + {"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "模型文件下载"}, + "404": {"description": "模型不存在"} + } + } + }, + "/api/model_types": { + "get": { + "tags": ["模型管理"], + "summary": "获取系统支持的所有模型类型", + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/models/{product_id}/{model_type}/versions": { + "get": { + "tags": ["模型管理"], + "summary": "获取模型版本列表", + "parameters": [ + {"name": "product_id", "in": "path", "required": true, "schema": {"type": "string"}}, + {"name": "model_type", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/stores": { + "get": { + "tags": ["数据管理"], + "summary": "获取所有店铺列表", + "responses": { + "200": {"description": "获取成功"} + } + }, + "post": { + "tags": ["数据管理"], + "summary": "创建新店铺", + "responses": { + "200": {"description": "创建成功"} + } + } + }, + "/api/stores/{store_id}": { + "get": { + "tags": ["数据管理"], + "summary": "获取单个店铺信息", + "parameters": [ + {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"}, + "404": {"description": "店铺不存在"} + } + }, + "put": { + "tags": ["数据管理"], + "summary": "更新店铺信息", + "parameters": [ + {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "更新成功"}, + "404": {"description": "店铺不存在"} + } + }, + "delete": { + "tags": ["数据管理"], + "summary": "删除店铺", + "parameters": [ + {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "删除成功"}, + "404": {"description": "店铺不存在"} + } + } + }, + "/api/stores/{store_id}/products": { + "get": { + "tags": ["数据管理"], + "summary": "获取店铺的产品列表", + "parameters": [ + {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/stores/{store_id}/statistics": { + "get": { + "tags": ["数据管理"], + "summary": "获取店铺销售统计信息", + "parameters": [ + {"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}} + ], + "responses": { + "200": {"description": "获取成功"} + } + } + }, + "/api/sales/data": { + "get": { + "tags": ["数据管理"], + "summary": "获取销售数据列表", + "responses": { + "200": {"description": "获取成功"} + } + } + } + } +} \ No newline at end of file diff --git a/server/trainers/cnn_bilstm_attention_trainer.py b/server/trainers/cnn_bilstm_attention_trainer.py index 35a9149..c8d4147 100644 --- a/server/trainers/cnn_bilstm_attention_trainer.py +++ b/server/trainers/cnn_bilstm_attention_trainer.py @@ -6,25 +6,29 @@ CNN-BiLSTM-Attention 模型训练器 import torch import torch.optim as optim import numpy as np +import time +import copy from models.model_registry import register_trainer from utils.model_manager import model_manager from analysis.metrics import evaluate_model from utils.data_utils import create_dataset from sklearn.preprocessing import MinMaxScaler +from utils.visualization import plot_loss_curve # 导入绘图函数 # 导入新创建的模型 from models.cnn_bilstm_attention import CnnBiLstmAttention def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): """ - 使用 CNN-BiLSTM-Attention 模型进行训练。 - 函数签名遵循系统标准。 + 使用 CNN-BiLSTM-Attention 模型进行训练,并实现早停和最佳模型保存。 """ print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'") + start_time = time.time() # --- 1. 数据准备 --- features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] + product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier X = product_df[features].values y = product_df[['sales']].values @@ -42,7 +46,6 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon) testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon) - # 转换为 PyTorch Tensors trainX = torch.from_numpy(trainX).float() trainY = torch.from_numpy(trainY).float() testX = torch.from_numpy(testX).float() @@ -60,22 +63,56 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001)) criterion = torch.nn.MSELoss() - # --- 3. 训练循环 --- - print("开始训练 CNN-BiLSTM-Attention 模型...") + # --- 3. 训练循环与早停 --- + print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...") + loss_history = {'train': [], 'val': []} + best_val_loss = float('inf') + best_model_state = None + patience = kwargs.get('patience', 15) + patience_counter = 0 + for epoch in range(epochs): model.train() optimizer.zero_grad() outputs = model(trainX) - loss = criterion(outputs, trainY.squeeze(-1)) # 确保目标维度匹配 + train_loss = criterion(outputs, trainY.squeeze(-1)) - loss.backward() + train_loss.backward() optimizer.step() + # 验证 + model.eval() + with torch.no_grad(): + val_outputs = model(testX) + val_loss = criterion(val_outputs, testY.squeeze(-1)) + + loss_history['train'].append(train_loss.item()) + loss_history['val'].append(val_loss.item()) + if (epoch + 1) % 10 == 0: - print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}') + print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}') + + # 早停逻辑 + if val_loss.item() < best_val_loss: + best_val_loss = val_loss.item() + best_model_state = copy.deepcopy(model.state_dict()) + patience_counter = 0 + print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}") + else: + patience_counter += 1 + if patience_counter >= patience: + print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。") + break + + training_time = time.time() - start_time + print(f"模型训练完成,耗时: {training_time:.2f}秒") + + # --- 4. 使用最佳模型进行评估 --- + if best_model_state: + model.load_state_dict(best_model_state) + print("最佳模型已加载用于最终评估。") - # --- 4. 模型评估 --- model.eval() with torch.no_grad(): test_pred_scaled = model(testX) @@ -84,11 +121,26 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy()) metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten()) - print(f"模型评估完成: RMSE={metrics['rmse']:.4f}") + metrics['training_time'] = training_time + metrics['best_val_loss'] = best_val_loss + metrics['stopped_epoch'] = epoch + 1 + + print("\n最佳模型评估指标:") + print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%") + + # 绘制损失曲线 + loss_curve_path = plot_loss_curve( + loss_history['train'], + loss_history['val'], + product_name, + 'cnn_bilstm_attention', + model_dir=model_dir + ) + print(f"📈 损失曲线已保存到: {loss_curve_path}") # --- 5. 模型保存 --- model_data = { - 'model_state_dict': model.state_dict(), + 'model_state_dict': best_model_state, # 保存最佳模型的状态 'scaler_X': scaler_X, 'scaler_y': scaler_y, 'config': { @@ -98,9 +150,12 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st 'sequence_length': sequence_length, 'features': features }, - 'metrics': metrics + 'metrics': metrics, + 'loss_history': loss_history, # 保存损失历史 + 'loss_curve_path': loss_curve_path # 添加损失图路径 } + # 保存最终版本模型 final_model_path, final_version = model_manager.save_model( model_data=model_data, product_id=product_id, @@ -108,10 +163,23 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st store_id=store_id, training_mode=training_mode, aggregation_method=aggregation_method, - product_name=product_df['product_name'].iloc[0] + product_name=product_name ) + print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}") + + # 保存最佳版本模型 + best_model_path, best_version = model_manager.save_model( + model_data=model_data, + product_id=product_id, + model_type='cnn_bilstm_attention', + store_id=store_id, + training_mode=training_mode, + aggregation_method=aggregation_method, + product_name=product_name, + version='best' # 明确指定版本为 'best' + ) + print(f"✅ CNN-BiLSTM-Attention 最佳模型已保存,版本: {best_version}") - print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}") return model, metrics, final_version, final_model_path # --- 关键步骤: 将训练器注册到系统中 --- diff --git a/server/trainers/xgboost_trainer.py b/server/trainers/xgboost_trainer.py index 9dff330..425c3bf 100644 --- a/server/trainers/xgboost_trainer.py +++ b/server/trainers/xgboost_trainer.py @@ -14,6 +14,7 @@ from utils.data_utils import create_dataset from analysis.metrics import evaluate_model from utils.model_manager import model_manager from models.model_registry import register_trainer +from utils.visualization import plot_loss_curve # 导入绘图函数 def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): """ @@ -91,7 +92,7 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s training_time = time.time() - start_time print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒") - # --- 4. 模型评估 --- + # --- 4. 模型评估与可视化 --- # 使用 model.best_iteration 获取最佳轮次的预测结果 test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration)) @@ -100,13 +101,24 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten()) metrics['training_time'] = training_time + metrics['best_iteration'] = model.best_iteration print("\n模型评估指标:") print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%") + # 提取损失并绘制曲线 + train_losses = evals_result['train']['rmse'] + test_losses = evals_result['test']['rmse'] + loss_curve_path = plot_loss_curve( + train_losses, + test_losses, + product_name, + 'xgboost', + model_dir=model_dir + ) + print(f"📈 损失曲线已保存到: {loss_curve_path}") + # --- 5. 模型保存 (借道 utils.model_manager) --- - # **关键适配点**: 我们将完整的XGBoost模型对象存入字典 - # torch.save 可以序列化多种Python对象,包括sklearn模型 model_data = { 'model_state_dict': model, # 直接保存模型对象 'scaler_X': scaler_X, @@ -119,10 +131,11 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s 'xgb_params': xgb_params }, 'metrics': metrics, - 'loss_history': evals_result + 'loss_history': evals_result, + 'loss_curve_path': loss_curve_path # 添加损失图路径 } - # 调用全局管理器进行保存,复用其命名和版本逻辑 + # 保存最终版本模型 final_model_path, final_version = model_manager.save_model( model_data=model_data, product_id=product_id, @@ -132,8 +145,20 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s aggregation_method=aggregation_method, product_name=product_name ) - - print(f"XGBoost模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}") + print(f"✅ XGBoost最终模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}") + + # 保存最佳版本模型 + best_model_path, best_version = model_manager.save_model( + model_data=model_data, + product_id=product_id, + model_type='xgboost', + store_id=store_id, + training_mode=training_mode, + aggregation_method=aggregation_method, + product_name=product_name, + version='best' # 明确指定版本为 'best' + ) + print(f"✅ XGBoost最佳模型已通过统一管理器保存,版本: {best_version}, 路径: {best_model_path}") # 返回值遵循统一格式 return model, metrics, final_version, final_model_path