ShopTRAINING/server/trainers/xgboost_trainer.py

193 lines
7.8 KiB
Python
Raw Normal View History

2025-07-22 15:40:37 +08:00
"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)
"""
import time
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping
# 导入核心工具
from utils.data_utils import create_dataset
2025-07-26 14:41:41 +08:00
from utils.new_data_loader import load_new_data
2025-07-22 15:40:37 +08:00
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
from utils.visualization import plot_loss_curve # 导入绘图函数
2025-07-22 15:40:37 +08:00
def _aggregate_by_date(df):
    """Collapse *df* to one row per ``date`` and replace missing values with 0.

    Sales are summed, temperature is averaged, and the calendar/promotion
    flags take the first value seen for each date.
    """
    daily_df = df.groupby('date').agg({
        'sales': 'sum', 'weekday': 'first', 'month': 'first',
        'is_holiday': 'first', 'is_weekend': 'first',
        'is_promotion': 'first', 'temperature': 'mean'
    }).reset_index()
    daily_df.fillna(0, inplace=True)
    return daily_df


def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train an XGBoost sales-forecasting model and save it via ``model_manager``.

    The signature mirrors the other trainers so the registry can call every
    trainer the same way.

    Args:
        product_id: Product to train on. In ``'store'`` mode a falsy value,
            ``'unknown'`` or ``'all_products'`` means "aggregate the whole
            store".
        model_identifier: Only used in the start-up log line.
        product_df: Pre-filtered training data. When ``None`` the data is
            loaded with ``load_new_data()`` and filtered by ``training_mode``.
        store_id: Store filter for ``'store'`` mode; also forwarded to
            ``model_manager.save_model``.
        training_mode: ``'store'``, ``'global'``, or anything else for the
            default per-product mode.
        aggregation_method: Forwarded to ``model_manager.save_model``.
        epochs: Not used by this trainer; accepted for signature compatibility.
        sequence_length: Number of past steps in each input window.
        forecast_horizon: Number of future steps predicted per window.
        model_dir: Directory handed to ``plot_loss_curve``.
        **kwargs: Optional overrides: ``learning_rate``, ``subsample``,
            ``colsample_bytree``, ``max_depth``, ``gamma``, ``n_estimators``.

    Returns:
        Tuple ``(model, metrics, final_version, final_model_path)``.

    Raises:
        ValueError: If there are fewer than
            ``sequence_length + forecast_horizon`` rows, or if the train or
            test split is too short to produce a single window.
    """
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data loading and filtering ---
    if product_df is None:
        print("正在使用新的统一数据加载器...")
        full_df = load_new_data()
        if training_mode == 'store' and store_id:
            store_df = full_df[full_df['store_id'] == store_id].copy()
            if product_id and product_id != 'unknown' and product_id != 'all_products':
                product_df = store_df[store_df['product_id'] == product_id].copy()
            else:
                product_df = _aggregate_by_date(store_df)
        elif training_mode == 'global':
            product_df = _aggregate_by_date(
                full_df[full_df['product_id'] == product_id].copy()
            )
        else:  # default 'product' mode
            product_df = full_df[full_df['product_id'] == product_id].copy()

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
        raise ValueError(error_msg)
    product_df = product_df.sort_values('date')

    product_name = f"产品 {product_id}"
    if 'product_name' in product_df.columns and not product_df['product_name'].empty:
        product_name = product_df['product_name'].iloc[0]

    # --- 2. Preprocessing and adaptation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values
    # NOTE(review): both scalers are fit on the full series, i.e. including
    # the rows that later become the test split.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Chronological 80/20 split (rows are already sorted by date).
    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]
    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # The length check above covers the whole series only; either split can
    # still be too short to yield a window. Fail with a clear message instead
    # of an obscure reshape/DMatrix error further down.
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"数据不足: 划分后无法生成训练/测试窗口 "
            f"(训练 {len(trainX)} 个, 测试 {len(testX)} 个)。"
        )

    # XGBoost needs 2-D input: flatten each window into one feature row.
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)
    # Wrap in DMatrix so the core xgb.train API can be used.
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

    # --- 3. Training (core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)
    print("开始训练XGBoost模型 (使用核心xgb.train API)...")
    start_time = time.time()
    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        # The last entry of `evals` ('test') drives early stopping.
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,
        evals_result=evals_result,
        verbose_eval=False
    )
    training_time = time.time() - start_time
    print(f"XGBoost模型训练完成耗时: {training_time:.2f}")

    # --- 4. Evaluation and visualisation ---
    # `best_iteration` is a 0-based round index while `iteration_range` is a
    # half-open interval, so the end bound must be best_iteration + 1 to
    # include the best round (and to avoid (0, 0), which means "all trees").
    best_round_end = model.best_iteration + 1
    test_pred = model.predict(dtest, iteration_range=(0, best_round_end))
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))
    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time
    metrics['best_iteration'] = model.best_iteration

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # Extract the per-round losses and plot the curve.
    train_losses = evals_result['train']['rmse']
    test_losses = evals_result['test']['rmse']
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'xgboost',
        model_dir=model_dir
    )
    print(f"📈 损失曲线已保存到: {loss_curve_path}")

    # --- 5. Persistence (through utils.model_manager) ---
    model_data = {
        'model_state_dict': model,  # the Booster object itself is stored
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result,
        'loss_curve_path': loss_curve_path
    }

    # Save the final-version model (version chosen by the manager).
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ XGBoost最终模型已通过统一管理器保存版本: {final_version}, 路径: {final_model_path}")

    # Save the same payload again under the explicit 'best' version.
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'
    )
    print(f"✅ XGBoost最佳模型已通过统一管理器保存版本: {best_version}, 路径: {best_model_path}")

    # Return value follows the format shared by all trainers.
    return model, metrics, final_version, final_model_path
# --- Register this trainer with the system under the key 'xgboost' ---
register_trainer('xgboost', train_product_model_with_xgboost)