"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)

Pharmacy sales forecasting system - XGBoost model trainer (plug-in).
"""

import time
|
||
import pandas as pd
|
||
import numpy as np
|
||
import xgboost as xgb
|
||
from sklearn.preprocessing import MinMaxScaler
|
||
from xgboost.callback import EarlyStopping
|
||
|
||
# 导入核心工具
|
||
from utils.data_utils import create_dataset
|
||
from analysis.metrics import evaluate_model
|
||
from utils.model_manager import model_manager
|
||
from models.model_registry import register_trainer
|
||
from utils.visualization import plot_loss_curve # 导入绘图函数
|
||
|
||
def train_product_model_with_xgboost(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """
    Train an XGBoost sales-forecasting model for a single product.

    The signature matches the other registered trainers so that the trainer
    registry can call every trainer the same way.

    Args:
        product_id: Product identifier. Used in error messages and forwarded
            to ``model_manager.save_model``.
        model_identifier: Used for logging. Also the fallback display name
            when ``product_df`` has no ``product_name`` column.
        product_df: Sales history. Must contain a ``date`` column and every
            column listed in ``features`` below.
        store_id, training_mode, aggregation_method: Forwarded unchanged to
            ``model_manager.save_model``.
        epochs: Not used by this trainer. The number of boosting rounds comes
            from ``kwargs['n_estimators']``. Kept for interface compatibility.
        sequence_length: Number of past rows per input window.
        forecast_horizon: Number of future steps predicted per window.
        model_dir: Directory passed to ``plot_loss_curve``.
        **kwargs: Optional hyper-parameters: ``learning_rate``, ``subsample``,
            ``colsample_bytree``, ``max_depth``, ``gamma``, ``n_estimators``
            (default 500) and ``early_stopping_rounds`` (default 50).

    Returns:
        tuple: ``(model, metrics, final_version, final_model_path)``.
        ``model`` is the booster returned by ``xgb.train``. ``metrics`` is
        the dict from ``evaluate_model`` extended with ``training_time`` and
        ``best_iteration``.

    Raises:
        ValueError: If ``product_df`` is empty or too short, or if the 80/20
            split leaves a part too short to form any train/test window.
    """
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data preparation and validation ---
    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (f"数据不足: 需要 {min_required_samples} 条, 实际 {len(product_df)} 条。")
        raise ValueError(error_msg)

    # Sort chronologically so the positional 80/20 split below is a time split.
    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    # --- 2. Preprocessing ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values

    # Split FIRST (first 80% of rows -> train, remainder -> test), then fit the
    # scalers on the training rows only. Fitting on the whole series would let
    # the test period's min/max leak into the training preprocessing.
    train_size = int(len(X) * 0.8)
    X_train_raw, X_test_raw = X[:train_size], X[train_size:]
    y_train_raw, y_test_raw = y[:train_size], y[train_size:]

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_train_scaled = scaler_X.fit_transform(X_train_raw)
    X_test_scaled = scaler_X.transform(X_test_raw)
    y_train_scaled = scaler_y.fit_transform(y_train_raw)
    y_test_scaled = scaler_y.transform(y_test_raw)

    trainX, trainY = create_dataset(X_train_scaled, y_train_scaled, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_scaled, y_test_scaled, sequence_length, forecast_horizon)

    # The whole-series length check above does not guarantee that each part of
    # the split is long enough to yield a window. Fail here with a clear message
    # instead of an obscure error later (numpy cannot reshape an empty array
    # with -1).
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"数据不足: 按 80/20 划分后训练窗口 {len(trainX)} 个, 测试窗口 {len(testX)} 个, "
            f"无法训练/评估 (共 {len(product_df)} 条数据)。"
        )

    # XGBoost needs 2-D input: flatten each window into a single feature row.
    trainX = trainX.reshape(trainX.shape[0], -1)
    testX = testX.reshape(testX.shape[0], -1)

    # Wrap the data in DMatrix objects to use the native xgb.train API.
    dtrain = xgb.DMatrix(trainX, label=trainY)
    dtest = xgb.DMatrix(testX, label=testY)

    # --- 3. Training (native xgb.train API) ---
    xgb_params = {
        'learning_rate': kwargs.get('learning_rate', 0.08),
        'subsample': kwargs.get('subsample', 0.75),
        'colsample_bytree': kwargs.get('colsample_bytree', 1),
        'max_depth': kwargs.get('max_depth', 7),
        'gamma': kwargs.get('gamma', 0),
        'objective': 'reg:squarederror',
        'eval_metric': 'rmse',
        'n_jobs': -1
    }
    n_estimators = kwargs.get('n_estimators', 500)
    # Rounds without improvement on the monitored eval set before stopping.
    early_stopping_rounds = kwargs.get('early_stopping_rounds', 50)

    print("开始训练XGBoost模型 (使用核心xgb.train API)...")
    start_time = time.time()

    evals_result = {}
    # NOTE(review): xgb.train early-stops on the LAST entry of `evals` (the
    # test set), and the metrics below are computed on that same set, so they
    # are optimistically biased. A separate validation split would avoid this.
    model = xgb.train(
        params=xgb_params,
        dtrain=dtrain,
        num_boost_round=n_estimators,
        evals=[(dtrain, 'train'), (dtest, 'test')],
        early_stopping_rounds=early_stopping_rounds,
        evals_result=evals_result,
        verbose_eval=False
    )

    training_time = time.time() - start_time
    print(f"XGBoost模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. Evaluation and visualisation ---
    # `best_iteration` is a 0-based round index, and the end of
    # `iteration_range` is exclusive, so +1 is needed to include the best
    # round itself. (A range of (0, 0) would mean "use all trees".)
    best_ntree_limit = model.best_iteration + 1
    test_pred = model.predict(dtest, iteration_range=(0, best_ntree_limit))

    # scaler_y was fitted on a single column, so invert on a (-1, 1) column
    # vector. Row-major flattening keeps predictions and targets aligned
    # element by element.
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time
    metrics['best_iteration'] = model.best_iteration

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # Per-round RMSE recorded by xgb.train (on the scaled target).
    train_losses = evals_result['train']['rmse']
    test_losses = evals_result['test']['rmse']
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'xgboost',
        model_dir=model_dir
    )
    print(f"📈 损失曲线已保存到: {loss_curve_path}")

    # --- 5. Persist through utils.model_manager ---
    model_data = {
        'model_state_dict': model,  # the Booster object itself is stored
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': 'xgboost',
            'features': features,
            'xgb_params': xgb_params
        },
        'metrics': metrics,
        'loss_history': evals_result,
        'loss_curve_path': loss_curve_path
    }

    # Save the final version (manager-assigned version label).
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    print(f"✅ XGBoost最终模型已通过统一管理器保存,版本: {final_version}, 路径: {final_model_path}")

    # NOTE(review): the 'best' version stores the same untruncated booster as
    # the final one. Consumers that want best-round predictions should pass
    # iteration_range=(0, metrics['best_iteration'] + 1) to predict.
    best_model_path, best_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='xgboost',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version='best'
    )
    print(f"✅ XGBoost最佳模型已通过统一管理器保存,版本: {best_version}, 路径: {best_model_path}")

    # Return tuple shape shared by all registered trainers.
    return model, metrics, final_version, final_model_path
|
||
|
||
# --- Plug-in registration ---
# Runs once at import time: hands this module's trainer function to the
# model registry under the name 'xgboost'.
register_trainer(
    'xgboost',
    train_product_model_with_xgboost,
)