ShopTRAINING/server/trainers/xgboost_trainer.py

198 lines
8.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - XGBoost 模型训练器 (插件式)
"""
import time
import os
import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
from xgboost.callback import EarlyStopping
import json
import torch
# 导入核心工具
from utils.data_utils import prepare_tabular_data
from analysis.metrics import evaluate_model
from utils.model_manager import model_manager
from models.model_registry import register_trainer
from utils.visualization import plot_loss_curve # 导入绘图函数
def _rollback_created_files(created_files):
    """Best-effort removal of artifact files created by a failed training run.

    Falsy entries (e.g. a helper that returned ``None`` instead of a path) are
    skipped, and only ``OSError`` from the filesystem is caught and reported,
    so the cleanup itself does not replace the exception that caused the
    rollback.
    """
    print("❌ 训练失败,正在回滚并删除已创建的文件...")
    for file_path in created_files:
        if not file_path:
            continue
        try:
            if os.path.exists(file_path):
                os.remove(file_path)
                print(f" - 已删除: {file_path}")
        except OSError as e:
            print(f" - 警告: 删除文件 '{file_path}' 失败: {e}")


def _write_loss_history(path, rounds, train_losses, test_losses):
    """Write the per-round train/test RMSE history as JSON to *path*."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump({
            'epochs': rounds,  # 1-based boosting-round numbers used as "epochs"
            'train_loss': train_losses,
            'test_loss': test_losses
        }, f)


def train_product_model_with_xgboost(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int = 500,  # XGBoost usually needs more rounds than the neural trainers
    sequence_length: int = 1,  # meaningless for a non-sequence model; kept for interface compatibility
    forecast_horizon: int = 1,
    model_dir: str = 'saved_models',
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    **kwargs
):
    """Train an XGBoost sales-forecast model (new data-pipeline version).

    Args:
        model_identifier: Label used in log output, and as the fallback
            product name when ``training_df`` has no ``product_name`` column.
        training_df: Training data; must be non-empty and contain the target
            column ``net_sales_quantity``.
        feature_list: Pre-selected feature columns passed to
            ``prepare_tabular_data``.
        training_mode: Scope of the model (``'product'``, ``'store'``,
            ``'global'`` or other); selects the identifier used for plots.
        epochs: Maximum number of boosting rounds (``num_boost_round``).
        sequence_length: Not used for training; stored in the checkpoint
            config for compatibility with sequence-model trainers.
        forecast_horizon: Accepted for interface compatibility; unused here.
        model_dir: Directory where checkpoints and loss artifacts are written.
        product_id, store_id, aggregation_method: Forwarded to
            ``model_manager`` for versioning and filename generation.
        version: Accepted for interface compatibility; unused — the version
            is taken from ``model_manager.peek_next_version``.
        **kwargs: Optional XGBoost overrides: ``learning_rate``,
            ``subsample``, ``colsample_bytree``, ``max_depth``, ``gamma``.

    Returns:
        ``(metrics, artifacts)``: the evaluation metrics dict (computed on
        inverse-scaled values, plus ``training_time`` and ``best_iteration``)
        and a dict of the artifact paths and the version string.

    Raises:
        ValueError: If ``training_df`` is empty. Any exception raised during
            training or saving propagates after the files created so far are
            deleted.
    """
    print(f"🚀 XGBoost训练器启动: model_identifier='{model_identifier}'")
    created_files = []
    success = False
    try:
        # --- 1. Data validation ---
        if training_df.empty:
            raise ValueError("用于训练的数据为空")
        product_name = (
            training_df['product_name'].iloc[0]
            if 'product_name' in training_df.columns
            else model_identifier
        )

        # --- 2. Preprocessing ---
        print(f"[XGBoost] 开始数据预处理,使用 {len(feature_list)} 个预选特征...")
        trainX, testX, trainY, testY, scaler_X, scaler_y, used_features = prepare_tabular_data(
            training_df=training_df,
            feature_list=feature_list,
            target_column='net_sales_quantity'
        )
        dtrain = xgb.DMatrix(trainX, label=trainY)
        dtest = xgb.DMatrix(testX, label=testY)

        # --- 3. Training ---
        xgb_params = {
            'learning_rate': kwargs.get('learning_rate', 0.08),
            'subsample': kwargs.get('subsample', 0.75),
            'colsample_bytree': kwargs.get('colsample_bytree', 1),
            'max_depth': kwargs.get('max_depth', 7),
            'gamma': kwargs.get('gamma', 0),
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'n_jobs': -1
        }
        # The epochs value passed from the frontend is the boosting-round budget.
        n_estimators = epochs
        print(f"开始训练XGBoost模型 (使用核心xgb.train API),共 {n_estimators} 轮...")
        # Lock the version before training so every artifact shares one version.
        current_version = model_manager.peek_next_version(
            model_type='xgboost', product_id=product_id, store_id=store_id,
            training_mode=training_mode, aggregation_method=aggregation_method
        )
        print(f"🔒 本次训练版本锁定为: {current_version}")
        start_time = time.time()
        evals_result = {}
        model = xgb.train(
            params=xgb_params, dtrain=dtrain, num_boost_round=n_estimators,
            evals=[(dtrain, 'train'), (dtest, 'test')],
            early_stopping_rounds=50, evals_result=evals_result, verbose_eval=False
        )
        training_time = time.time() - start_time
        print(f"XGBoost模型训练完成耗时: {training_time:.2f}")

        # --- 4. Evaluation ---
        best_iteration = model.best_iteration
        # iteration_range is half-open [begin, end): +1 includes the best round.
        # (The previous (0, best_iteration) dropped it, and (0, 0) means "all trees".)
        test_pred = model.predict(dtest, iteration_range=(0, best_iteration + 1))
        # Inverse-scale both predictions and ground truth so metrics use real units.
        test_pred_inv = scaler_y.inverse_transform(np.asarray(test_pred).reshape(-1, 1))
        test_true_inv = scaler_y.inverse_transform(np.asarray(testY).reshape(-1, 1))
        metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
        metrics.update({'training_time': training_time, 'best_iteration': best_iteration})
        print(f"\n模型评估指标 (真实值): MSE={metrics['mse']:.4f}, RMSE={metrics['rmse']:.4f}, MAE={metrics['mae']:.4f}, R²={metrics['r2']:.4f}, MAPE={metrics['mape']:.2f}%")

        # --- 5. Save artifacts ---
        scope = training_mode
        if scope == 'product':
            identifier = product_id
        elif scope == 'store':
            identifier = store_id
        elif scope == 'global':
            identifier = aggregation_method
        else:
            identifier = product_name

        # Missing eval sets yield empty histories rather than a KeyError.
        train_losses = evals_result.get('train', {}).get('rmse', [])
        test_losses = evals_result.get('test', {}).get('rmse', [])
        rounds = list(range(1, len(train_losses) + 1))  # x-axis: 1-based boosting rounds

        loss_curve_path = plot_loss_curve(
            train_losses=train_losses, val_losses=test_losses,
            model_type='xgboost', scope=scope, identifier=identifier,
            version=current_version, model_dir=model_dir,
            x_axis_data=rounds
        )
        created_files.append(loss_curve_path)
        print(f"📈 损失曲线已保存: {loss_curve_path}")

        config = {
            'model_type': 'xgboost',
            'features': used_features,
            'sequence_length': sequence_length,  # for compatibility
            'params': xgb_params,
            # 0-based index of the best round. The saved booster holds ALL
            # trained trees, so a loader should predict with
            # iteration_range=(0, best_iteration + 1).
            'best_iteration': best_iteration
        }
        checkpoint = {
            'model_raw': model.save_raw(),  # booster serialized to a bytearray
            'config': config,
            'scaler_X': scaler_X,
            'scaler_y': scaler_y
        }

        base_model_filename = model_manager.generate_model_filename(
            model_type='xgboost', version=current_version, training_mode=training_mode,
            product_id=product_id, store_id=store_id, aggregation_method=aggregation_method
        )
        final_model_path = os.path.join(model_dir, base_model_filename)
        torch.save(checkpoint, final_model_path)
        created_files.append(final_model_path)
        print(f"✅ 最终模型Checkpoint已创建: {final_model_path}")

        # The "best" checkpoint has identical contents; only the filename differs.
        best_model_filename = model_manager.generate_model_filename(
            model_type='xgboost', version=f"{current_version}_best", training_mode=training_mode,
            product_id=product_id, store_id=store_id, aggregation_method=aggregation_method
        )
        best_model_path = os.path.join(model_dir, best_model_filename)
        torch.save(checkpoint, best_model_path)
        created_files.append(best_model_path)
        print(f"✅ 最佳模型Checkpoint已创建: {best_model_path}")

        # final_model_path already includes model_dir; re-joining it with
        # model_dir (as before) doubled the directory for relative paths.
        loss_data_path = f"{os.path.splitext(final_model_path)[0]}_loss_curve_data.json"
        _write_loss_history(loss_data_path, rounds, train_losses, test_losses)
        created_files.append(loss_data_path)
        print(f"💾 损失历史数据已保存: {loss_data_path}")

        artifacts = {
            "versioned_model": final_model_path,
            "loss_curve_plot": loss_curve_path,
            "loss_curve_data": loss_data_path,
            "best_model": best_model_path,
            "version": current_version
        }
        success = True
        return metrics, artifacts
    finally:
        # Roll back partial artifacts on any failure; the exception still propagates.
        if not success:
            _rollback_created_files(created_files)
# --- Register this trainer with the system ---
# Executed at import time: registers train_product_model_with_xgboost
# under the key 'xgboost' via models.model_registry.register_trainer.
register_trainer('xgboost', train_product_model_with_xgboost)