ShopTRAINING/server/trainers/xgboost_trainer.py


"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)
"""
import time
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
# Core project utilities
from utils.data_utils import create_dataset
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
from utils.visualization import plot_loss_curve  # loss-curve plotting helper


def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id,
                                     training_mode, aggregation_method, epochs, sequence_length,
                                     forecast_horizon, model_dir, **kwargs):
    """
    Train a product sales forecasting model with XGBoost.

    The signature mirrors the other trainers so the registry can dispatch to it
    interchangeably; `epochs` is accepted for interface parity but is not used here,
    since the number of boosting rounds comes from the `n_estimators` kwarg.
    """
    print(f"🚀 XGBoost trainer started: model_identifier='{model_identifier}'")

    # --- 1. Data preparation and validation ---
    if product_df.empty:
        raise ValueError(f"No sales data available for product {product_id}")
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"Insufficient data: {min_required_samples} rows required, "
                     f"{len(product_df)} rows available.")
        raise ValueError(error_msg)
    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    # --- 2. Preprocessing and adaptation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
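    # Note: both scalers are fit on the full series before the chronological split,
    # so statistics from the test period leak into the scaling. Fitting them on the
    # training slice only would be the stricter choice.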
    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # Key adaptation: XGBoost expects 2D input, so flatten each (sequence_length, n_features) window
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

    # Key adaptation: convert to XGBoost's core DMatrix format so the stable xgb.train API can be used
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)
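    # Shape note: with the 7 features above, trainX is flattened from
    # (n_windows, sequence_length, 7) to (n_windows, sequence_length * 7), while trainY keeps
    # shape (n_windows, forecast_horizon). When forecast_horizon > 1 this is a multi-target
    # label, which the native xgb.train API only handles in recent XGBoost releases.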

    # --- 3. Model training (core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse'   # natively supported as a booster parameter here
        # Thread count deliberately omitted: the native API uses all available cores by default
        # (the sklearn wrapper's 'n_jobs' is not a recognised booster parameter and only causes a warning).
    }
    n_estimators = kwargs.get('n_estimators', 500)

    print("Training XGBoost model (core xgb.train API)...")
    start_time = time.time()
    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,   # natively supported by xgb.train
        evals_result=evals_result,
        verbose_eval=False
    )
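    # With early stopping enabled, xgb.train records best_iteration/best_score on the returned
    # Booster, and evals_result holds the per-round 'rmse' history for both eval sets
    # (used below for the loss curve).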
    training_time = time.time() - start_time
    print(f"XGBoost training finished in {training_time:.2f}s")

    # --- 4. Evaluation and visualisation ---
    # Predict with the best early-stopping round; iteration_range's upper bound is exclusive,
    # hence best_iteration + 1
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration + 1))
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time
    metrics['best_iteration'] = model.best_iteration

    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, "
          f"R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # Extract the loss history and plot the curve
    train_losses = evals_result['train']['rmse']
    test_losses = evals_result['test']['rmse']
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'xgboost',
        model_dir=model_dir
    )
    print(f"📈 Loss curve saved to: {loss_curve_path}")

    # --- 5. Model saving (via utils.model_manager) ---
    model_data = {
        'model_state_dict': model,   # store the Booster object itself under the shared key name
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result,
        'loss_curve_path': loss_curve_path   # path to the saved loss plot
    }
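    # Assumption: model_manager serialises model_data as a whole Python object (e.g. via pickle),
    # so the raw Booster round-trips with it. XGBoost's own Booster.save_model()/load_model()
    # (JSON/UBJ) is the version-stable alternative if cross-version loading ever becomes a concern.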

    # Save the final model version
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ Final XGBoost model saved via the unified manager, version: {final_version}, path: {final_model_path}")

    # Save the best model version
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'   # explicitly mark this version as 'best'
    )
    print(f"✅ Best XGBoost model saved via the unified manager, version: {best_version}, path: {best_model_path}")

    # Return values follow the unified trainer format
    return model, metrics, final_version, final_model_path


# --- Register this trainer with the system ---
register_trainer('xgboost', train_product_model_with_xgboost)
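
# Minimal usage sketch (illustrative only): the trainer is normally dispatched through the
# registry by the API layer, but it can be exercised directly with a synthetic DataFrame,
# assuming the project's utils/analysis/models packages are importable. The identifiers
# 'my_product', 'S001', './test_models' and the mode/aggregation values below are placeholders,
# not values taken from the real system.
if __name__ == '__main__':
    rng = np.random.default_rng(42)
    n_days = 400
    dates = pd.date_range('2024-01-01', periods=n_days, freq='D')
    demo_df = pd.DataFrame({
        'date': dates,
        'product_name': 'Demo Product',
        'sales': rng.poisson(20, n_days).astype(float),
        'weekday': dates.weekday,
        'month': dates.month,
        'is_holiday': 0,
        'is_weekend': (dates.weekday >= 5).astype(int),
        'is_promotion': 0,
        'temperature': 20 + 10 * np.sin(np.arange(n_days) * 2 * np.pi / 365),
    })

    train_product_model_with_xgboost(
        product_id='my_product',
        model_identifier='my_product_xgb_demo',
        product_df=demo_df,
        store_id='S001',
        training_mode='product',
        aggregation_method='sum',
        epochs=1,                 # unused by XGBoost; kept for signature parity
        sequence_length=30,
        forecast_horizon=7,
        model_dir='./test_models',
        n_estimators=200,
    )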