""" 药店销售预测系统 - XGBoost 模型训练器 (插件式) """ import time import pandas as pd import numpy as np import xgboost as xgb from sklearn.preprocessing import MinMaxScaler from xgboost.callback import EarlyStopping # 导入核心工具 from utils.data_utils import create_dataset from utils.new_data_loader import load_new_data from analysis.metrics import evaluate_model from utils.model_manager import model_manager from models.model_registry import register_trainer from utils.visualization import plot_loss_curve # 导入绘图函数 def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs): """ 使用 XGBoost 模型训练产品销售预测模型。 此函数签名与其他训练器保持一致,以兼容注册表调用。 """ print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'") # --- 1. 数据加载与筛选重构 --- if product_df is None: print("正在使用新的统一数据加载器...") full_df = load_new_data() if training_mode == 'store' and store_id: store_df = full_df[full_df['store_id'] == store_id].copy() if product_id and product_id != 'unknown' and product_id != 'all_products': product_df = store_df[store_df['product_id'] == product_id].copy() else: product_df = store_df.groupby('date').agg({ 'sales': 'sum', 'weekday': 'first', 'month': 'first', 'is_holiday': 'first', 'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean' }).reset_index() product_df.fillna(0, inplace=True) elif training_mode == 'global': product_df = full_df[full_df['product_id'] == product_id].copy() product_df = product_df.groupby('date').agg({ 'sales': 'sum', 'weekday': 'first', 'month': 'first', 'is_holiday': 'first', 'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean' }).reset_index() product_df.fillna(0, inplace=True) else: # 默认 'product' 模式 product_df = full_df[full_df['product_id'] == product_id].copy() min_required_samples = sequence_length + forecast_horizon if len(product_df) < min_required_samples: error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。") raise ValueError(error_msg) product_df = product_df.sort_values('date') product_name = f"产品 {product_id}" if 'product_name' in product_df.columns and not product_df['product_name'].empty: product_name = product_df['product_name'].iloc[0] # --- 2. 数据预处理和适配 --- features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] X = product_df[features].values y = product_df[['sales']].values scaler_X = MinMaxScaler(feature_range=(0, 1)) scaler_y = MinMaxScaler(feature_range=(0, 1)) X_scaled = scaler_X.fit_transform(X) y_scaled = scaler_y.fit_transform(y) train_size = int(len(X_scaled) * 0.8) X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:] y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:] trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon) testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon) # **关键适配步骤**: XGBoost 需要二维输入 trainX = trainX.reshape(trainX.shape[0], -1) testX = testX.reshape(testX.shape[0], -1) # **关键适配**: 转换为 XGBoost 核心 DMatrix 格式,以使用稳定的 xgb.train API dtrain = xgb.DMatrix(trainX, label=trainY) dtest = xgb.DMatrix(testX, label=testY) # --- 3. 
    # --- 3. Model training (core xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',  # eval_metric is natively supported here
        'nthread': -1           # core API uses 'nthread' rather than the sklearn-style 'n_jobs'
    }
    n_estimators = kwargs.get('n_estimators', 500)

    print("Training XGBoost model (core xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=50,  # early_stopping_rounds is natively supported here
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost training finished in {training_time:.2f}s")

    # --- 4. Evaluation and visualization ---
    # Predict with the best boosting rounds; iteration_range is half-open, hence the +1
    test_pred = model.predict(dtest, iteration_range=(0, model.best_iteration + 1))

    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, forecast_horizon))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, forecast_horizon))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time
    metrics['best_iteration'] = model.best_iteration

    print("\nEvaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, "
          f"MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # Extract the loss history and plot the training/validation curves
    train_losses = evals_result['train']['rmse']
    test_losses = evals_result['test']['rmse']
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'xgboost',
        model_dir=model_dir
    )
    print(f"📈 Loss curve saved to: {loss_curve_path}")

    # --- 5. Model saving (via utils.model_manager) ---
    model_data = {
        'model_state_dict': model,  # store the Booster object directly
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result,
        'loss_curve_path': loss_curve_path  # path to the loss plot
    }

    # Save the final model version
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ Final XGBoost model saved via the unified manager, version: {final_version}, path: {final_model_path}")

    # Save the best model version
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'  # explicitly tag this save as the 'best' version
    )
    print(f"✅ Best XGBoost model saved via the unified manager, version: {best_version}, path: {best_model_path}")

    # Return values follow the unified trainer contract
    return model, metrics, final_version, final_model_path


# --- Register this trainer with the system ---
register_trainer('xgboost', train_product_model_with_xgboost)
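

# --- Usage sketch (illustrative only) ---
# A minimal example of calling the trainer directly rather than through the registry.
# The product id, identifier and model_dir below are hypothetical placeholders, and the
# run assumes load_new_data() can supply the columns referenced above.
if __name__ == "__main__":
    booster, eval_metrics, version, path = train_product_model_with_xgboost(
        product_id='P001',            # hypothetical product id
        model_identifier='xgb_demo',  # hypothetical identifier
        product_df=None,              # let the trainer load data via load_new_data()
        store_id=None,
        training_mode='product',
        aggregation_method='sum',
        epochs=None,                  # unused by XGBoost; kept for signature parity
        sequence_length=30,
        forecast_horizon=7,
        model_dir='./saved_models',   # hypothetical output directory
        n_estimators=300
    )
    print(f"Demo run finished: version={version}, RMSE={eval_metrics['rmse']:.4f}")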