一、使用Swagger UI 展示药店销售预测系统API

二、完成新增模型 xgboost、cnn_bilstm_attention 的训练与预测
This commit is contained in:
xz2000 2025-07-23 16:55:27 +08:00
parent af3d174ac6
commit 9d7dcae1c8
3 changed files with 579 additions and 21 deletions

465
server/swagger.json Normal file
View File

@ -0,0 +1,465 @@
{
"openapi": "3.0.0",
"info": {
"title": "药店销售预测系统API",
"description": "用于药店销售预测的RESTful API",
"version": "1.0.0",
"contact": {
"name": "API开发团队",
"email": "support@example.com"
}
},
"tags": [
{
"name": "数据管理",
"description": "数据上传和查询相关接口"
},
{
"name": "模型训练",
"description": "模型训练相关接口"
},
{
"name": "模型预测",
"description": "预测销售数据相关接口"
},
{
"name": "模型管理",
"description": "模型查询、导出和删除接口"
}
],
"paths": {
"/api/products": {
"get": {
"tags": ["数据管理"],
"summary": "获取所有产品列表",
"description": "返回系统中所有产品的ID和名称",
"responses": {
"200": {
"description": "成功获取产品列表",
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"status": {"type": "string"},
"data": {
"type": "array",
"items": {
"type": "object",
"properties": {
"product_id": {"type": "string"},
"product_name": {"type": "string"}
}
}
}
}
}
}
}
},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/products/{product_id}": {
"get": {
"tags": ["数据管理"],
"summary": "获取单个产品详情",
"description": "返回指定产品ID的详细信息",
"parameters": [
{
"name": "product_id",
"in": "path",
"required": true,
"schema": {"type": "string"},
"description": "产品ID,例如P001"
}
],
"responses": {
"200": {"description": "成功获取产品详情"},
"404": {"description": "产品不存在"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/products/{product_id}/sales": {
"get": {
"tags": ["数据管理"],
"summary": "获取产品销售数据",
"description": "返回指定产品在特定日期范围内的销售数据",
"parameters": [
{
"name": "product_id",
"in": "path",
"required": true,
"schema": {"type": "string"},
"description": "产品ID,例如P001"
},
{
"name": "start_date",
"in": "query",
"schema": {"type": "string"},
"description": "开始日期,格式为YYYY-MM-DD"
},
{
"name": "end_date",
"in": "query",
"schema": {"type": "string"},
"description": "结束日期,格式为YYYY-MM-DD"
}
],
"responses": {
"200": {"description": "成功获取销售数据"},
"404": {"description": "产品不存在"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/data/upload": {
"post": {
"tags": ["数据管理"],
"summary": "上传销售数据",
"description": "上传新的销售数据文件(Excel格式)",
"requestBody": {
"content": {
"multipart/form-data": {
"schema": {
"type": "object",
"properties": {
"file": {
"type": "string",
"format": "binary"
}
}
}
}
}
},
"responses": {
"200": {"description": "数据上传成功"},
"400": {"description": "请求错误"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/training": {
"get": {
"tags": ["模型训练"],
"summary": "获取所有训练任务列表",
"description": "返回所有正在进行、已完成或失败的训练任务",
"responses": {
"200": {"description": "成功获取任务列表"}
}
},
"post": {
"tags": ["模型训练"],
"summary": "启动模型训练任务",
"description": "为指定产品启动一个新的模型训练任务",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"product_id": {"type": "string", "example": "P001"},
"model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn", "xgboost", "cnn_bilstm_attention"]},
"store_id": {"type": "string", "example": "S001"},
"epochs": {"type": "integer", "default": 50}
},
"required": ["product_id", "model_type"]
}
}
}
},
"responses": {
"200": {"description": "训练任务已启动"},
"400": {"description": "请求错误"}
}
}
},
"/api/training/{task_id}": {
"get": {
"tags": ["模型训练"],
"summary": "查询训练任务状态",
"description": "获取特定训练任务的当前状态和详情",
"parameters": [
{
"name": "task_id",
"in": "path",
"required": true,
"schema": {"type": "string"},
"description": "训练任务ID"
}
],
"responses": {
"200": {"description": "成功获取任务状态"},
"404": {"description": "任务不存在"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/prediction": {
"post": {
"tags": ["模型预测"],
"summary": "使用模型进行预测",
"description": "使用指定模型预测未来销售数据",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"product_id": {"type": "string"},
"model_type": {"type": "string", "enum": ["mlstm", "transformer", "kan", "optimized_kan", "tcn", "xgboost", "cnn_bilstm_attention"]},
"store_id": {"type": "string"},
"version": {"type": "string"},
"future_days": {"type": "integer"},
"include_visualization": {"type": "boolean"},
"start_date": {"type": "string"}
},
"required": ["product_id", "model_type"]
}
}
}
},
"responses": {
"200": {"description": "预测成功"},
"400": {"description": "请求错误"},
"404": {"description": "产品或模型不存在"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/prediction/compare": {
"post": {
"tags": ["模型预测"],
"summary": "比较不同模型预测结果",
"description": "比较不同模型对同一产品的预测结果",
"requestBody": {
"required": true,
"content": {
"application/json": {
"schema": {
"type": "object",
"properties": {
"product_id": {"type": "string"},
"model_types": {"type": "array", "items": {"type": "string"}},
"versions": {"type": "array", "items": {"type": "string"}},
"include_visualization": {"type": "boolean"}
},
"required": ["product_id", "model_types"]
}
}
}
},
"responses": {
"200": {"description": "比较成功"},
"400": {"description": "请求错误"},
"404": {"description": "产品或模型不存在"},
"500": {"description": "服务器内部错误"}
}
}
},
"/api/prediction/history": {
"get": {
"tags": ["模型预测"],
"summary": "获取历史预测记录",
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/prediction/history/{prediction_id}": {
"get": {
"tags": ["模型预测"],
"summary": "获取特定预测记录的详情",
"parameters": [
{
"name": "prediction_id",
"in": "path",
"required": true,
"schema": {"type": "string"}
}
],
"responses": {
"200": {"description": "获取成功"},
"404": {"description": "记录不存在"}
}
},
"delete": {
"tags": ["模型预测"],
"summary": "删除预测记录",
"parameters": [
{
"name": "prediction_id",
"in": "path",
"required": true,
"schema": {"type": "string"}
}
],
"responses": {
"200": {"description": "删除成功"},
"404": {"description": "记录不存在"}
}
}
},
"/api/models": {
"get": {
"tags": ["模型管理"],
"summary": "获取模型列表",
"parameters": [
{"name": "product_id", "in": "query", "schema": {"type": "string"}},
{"name": "model_type", "in": "query", "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/models/{model_id}": {
"get": {
"tags": ["模型管理"],
"summary": "获取模型详情",
"parameters": [
{"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"},
"404": {"description": "模型不存在"}
}
},
"delete": {
"tags": ["模型管理"],
"summary": "删除模型",
"parameters": [
{"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "删除成功"},
"404": {"description": "模型不存在"}
}
}
},
"/api/models/{model_id}/export": {
"get": {
"tags": ["模型管理"],
"summary": "导出模型",
"parameters": [
{"name": "model_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "模型文件下载"},
"404": {"description": "模型不存在"}
}
}
},
"/api/model_types": {
"get": {
"tags": ["模型管理"],
"summary": "获取系统支持的所有模型类型",
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/models/{product_id}/{model_type}/versions": {
"get": {
"tags": ["模型管理"],
"summary": "获取模型版本列表",
"parameters": [
{"name": "product_id", "in": "path", "required": true, "schema": {"type": "string"}},
{"name": "model_type", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/stores": {
"get": {
"tags": ["数据管理"],
"summary": "获取所有店铺列表",
"responses": {
"200": {"description": "获取成功"}
}
},
"post": {
"tags": ["数据管理"],
"summary": "创建新店铺",
"responses": {
"200": {"description": "创建成功"}
}
}
},
"/api/stores/{store_id}": {
"get": {
"tags": ["数据管理"],
"summary": "获取单个店铺信息",
"parameters": [
{"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"},
"404": {"description": "店铺不存在"}
}
},
"put": {
"tags": ["数据管理"],
"summary": "更新店铺信息",
"parameters": [
{"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "更新成功"},
"404": {"description": "店铺不存在"}
}
},
"delete": {
"tags": ["数据管理"],
"summary": "删除店铺",
"parameters": [
{"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "删除成功"},
"404": {"description": "店铺不存在"}
}
}
},
"/api/stores/{store_id}/products": {
"get": {
"tags": ["数据管理"],
"summary": "获取店铺的产品列表",
"parameters": [
{"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/stores/{store_id}/statistics": {
"get": {
"tags": ["数据管理"],
"summary": "获取店铺销售统计信息",
"parameters": [
{"name": "store_id", "in": "path", "required": true, "schema": {"type": "string"}}
],
"responses": {
"200": {"description": "获取成功"}
}
}
},
"/api/sales/data": {
"get": {
"tags": ["数据管理"],
"summary": "获取销售数据列表",
"responses": {
"200": {"description": "获取成功"}
}
}
}
}
}

View File

@ -6,25 +6,29 @@ CNN-BiLSTM-Attention 模型训练器
import torch
import torch.optim as optim
import numpy as np
import time
import copy
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler
from utils.visualization import plot_loss_curve # 导入绘图函数
# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention
def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
"""
使用 CNN-BiLSTM-Attention 模型进行训练
函数签名遵循系统标准
使用 CNN-BiLSTM-Attention 模型进行训练并实现早停和最佳模型保存
"""
print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
start_time = time.time()
# --- 1. 数据准备 ---
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier
X = product_df[features].values
y = product_df[['sales']].values
@ -42,7 +46,6 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
# 转换为 PyTorch Tensors
trainX = torch.from_numpy(trainX).float()
trainY = torch.from_numpy(trainY).float()
testX = torch.from_numpy(testX).float()
@ -60,22 +63,56 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
criterion = torch.nn.MSELoss()
# --- 3. 训练循环 ---
print("开始训练 CNN-BiLSTM-Attention 模型...")
# --- 3. 训练循环与早停 ---
print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")
loss_history = {'train': [], 'val': []}
best_val_loss = float('inf')
best_model_state = None
patience = kwargs.get('patience', 15)
patience_counter = 0
for epoch in range(epochs):
model.train()
optimizer.zero_grad()
outputs = model(trainX)
loss = criterion(outputs, trainY.squeeze(-1)) # 确保目标维度匹配
train_loss = criterion(outputs, trainY.squeeze(-1))
loss.backward()
train_loss.backward()
optimizer.step()
# 验证
model.eval()
with torch.no_grad():
val_outputs = model(testX)
val_loss = criterion(val_outputs, testY.squeeze(-1))
loss_history['train'].append(train_loss.item())
loss_history['val'].append(val_loss.item())
if (epoch + 1) % 10 == 0:
print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')
print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss.item():.4f}, Val Loss: {val_loss.item():.4f}')
# 早停逻辑
if val_loss.item() < best_val_loss:
best_val_loss = val_loss.item()
best_model_state = copy.deepcopy(model.state_dict())
patience_counter = 0
print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")
else:
patience_counter += 1
if patience_counter >= patience:
print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
break
training_time = time.time() - start_time
print(f"模型训练完成,耗时: {training_time:.2f}")
# --- 4. 使用最佳模型进行评估 ---
if best_model_state:
model.load_state_dict(best_model_state)
print("最佳模型已加载用于最终评估。")
# --- 4. 模型评估 ---
model.eval()
with torch.no_grad():
test_pred_scaled = model(testX)
@ -84,11 +121,26 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
test_true_unscaled = scaler_y.inverse_transform(testY.squeeze(-1).numpy())
metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")
metrics['training_time'] = training_time
metrics['best_val_loss'] = best_val_loss
metrics['stopped_epoch'] = epoch + 1
print("\n最佳模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
# 绘制损失曲线
loss_curve_path = plot_loss_curve(
loss_history['train'],
loss_history['val'],
product_name,
'cnn_bilstm_attention',
model_dir=model_dir
)
print(f"📈 损失曲线已保存到: {loss_curve_path}")
# --- 5. 模型保存 ---
model_data = {
'model_state_dict': model.state_dict(),
'model_state_dict': best_model_state, # 保存最佳模型的状态
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
@ -98,9 +150,12 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
'sequence_length': sequence_length,
'features': features
},
'metrics': metrics
'metrics': metrics,
'loss_history': loss_history, # 保存损失历史
'loss_curve_path': loss_curve_path # 添加损失图路径
}
# 保存最终版本模型
final_model_path, final_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
@ -108,10 +163,23 @@ def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, st
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_df['product_name'].iloc[0]
product_name=product_name
)
print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")
# 保存最佳版本模型
best_model_path, best_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
model_type='cnn_bilstm_attention',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best' # 明确指定版本为 'best'
)
print(f"✅ CNN-BiLSTM-Attention 最佳模型已保存,版本: {best_version}")
print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
return model, metrics, final_version, final_model_path
# --- 关键步骤: 将训练器注册到系统中 ---

View File

@ -14,6 +14,7 @@ from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
from utils.visualization import plot_loss_curve # 导入绘图函数
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
"""
@ -91,7 +92,7 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
training_time = time.time() - start_time
print(f"XGBoost模型训练完成耗时: {training_time:.2f}")
# --- 4. 模型评估 ---
# --- 4. 模型评估与可视化 ---
# Predict with the trees up to and including the best early-stopping round.
# iteration_range is a half-open interval [begin, end) and best_iteration is
# zero-based, so the end bound must be best_iteration + 1. Passing
# best_iteration itself drops the best round, and when best_iteration == 0 it
# yields (0, 0), which XGBoost appears to treat as "use every tree" (its default) — confirm.
test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration + 1))
@ -100,13 +101,24 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
metrics['training_time'] = training_time
metrics['best_iteration'] = model.best_iteration
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")
# 提取损失并绘制曲线
train_losses = evals_result['train']['rmse']
test_losses = evals_result['test']['rmse']
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
product_name,
'xgboost',
model_dir=model_dir
)
print(f"📈 损失曲线已保存到: {loss_curve_path}")
# --- 5. 模型保存 (借道 utils.model_manager) ---
# **关键适配点**: 我们将完整的XGBoost模型对象存入字典
# torch.save 可以序列化多种Python对象包括sklearn模型
model_data = {
'model_state_dict': model, # 直接保存模型对象
'scaler_X': scaler_X,
@ -119,10 +131,11 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
'xgb_params': xgb_params
},
'metrics': metrics,
'loss_history': evals_result
'loss_history': evals_result,
'loss_curve_path': loss_curve_path # 添加损失图路径
}
# 调用全局管理器进行保存,复用其命名和版本逻辑
# 保存最终版本模型
final_model_path, final_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
@ -132,8 +145,20 @@ def train_product_model_with_xgboost(product_id, model_identifier, product_df, s
aggregation_method=aggregation_method,
product_name=product_name
)
print(f"XGBoost模型已通过统一管理器保存版本: {final_version}, 路径: {final_model_path}")
print(f"✅ XGBoost最终模型已通过统一管理器保存版本: {final_version}, 路径: {final_model_path}")
# 保存最佳版本模型
best_model_path, best_version = model_manager.save_model(
model_data=model_data,
product_id=product_id,
model_type='xgboost',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best' # 明确指定版本为 'best'
)
print(f"✅ XGBoost最佳模型已通过统一管理器保存版本: {best_version}, 路径: {best_model_path}")
# 返回值遵循统一格式
return model, metrics, final_version, final_model_path