The data file storage scheme has been changed as follows:

### 1.2. File Storage Locations

- **Final artifacts**: all final models, metadata files, loss plots, etc. live directly under the `saved_models/` root directory.
- **Intermediate files**: all checkpoint files produced during training live under `saved_models/checkpoints/`.

### 1.3. Filename Generation Rules

1. **Build the logical path**: determined by the training parameters (mode, scope, type, version).
   - *Example*: `product/P001_all/mlstm/v2`
2. **Generate the filename prefix**: replace every `/` in the logical path with `_`.
   - *Example*: `product_P001_all_mlstm_v2`
3. **Append the file-type suffix** to the prefix:
   - `_model.pth`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### Complete examples:

- **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
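These rules are mechanical enough to capture in a few lines. A minimal sketch (the helper name `build_model_filename` is illustrative, not part of the project's API):

```python
import os

def build_model_filename(logical_path: str, suffix: str, root: str = "saved_models") -> str:
    """Map a logical path like 'product/P001_all/mlstm/v2' to a concrete file path."""
    prefix = logical_path.replace("/", "_")  # e.g. product_P001_all_mlstm_v2
    # Per section 1.2: checkpoints go under saved_models/checkpoints/, final
    # artifacts directly under the root
    subdir = "checkpoints" if "checkpoint" in suffix else ""
    return os.path.join(root, subdir, f"{prefix}{suffix}")

# build_model_filename("product/P001_all/mlstm/v2", "_model.pth")
#   -> saved_models/product_P001_all_mlstm_v2_model.pth
# build_model_filename("product/P001_all/mlstm/v2", "_checkpoint_best.pth")
#   -> saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth
```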
The refactored training script (297 lines, 11 KiB, Python):

```python
import os
import joblib
import numpy as np
import pandas as pd
import xgboost as xgb
import matplotlib.pyplot as plt
from xgboost.callback import EarlyStopping
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.preprocessing import MinMaxScaler

# Import the project's utility functions and configuration
from utils.multi_store_data_utils import load_multi_store_data
from core.config import DEFAULT_DATA_PATH
from utils.file_save import ModelPathManager
from analysis.metrics import evaluate_model

# Refactored callback compatible with the native XGBoost API
class EpochCheckpointCallback(xgb.callback.TrainingCallback):
    def __init__(self, save_period, payload, base_path):
        super().__init__()
        self.save_period = save_period
        self.payload = payload
        self.base_path = base_path
        self.best_score = float('inf')

    def _save_checkpoint(self, model, path_suffix):
        """Helper that saves a model/metadata checkpoint pair."""
        metadata_path = self.base_path.replace('_model.pth', f'_{path_suffix}.pth')
        model_file_path = metadata_path.replace('.pth', '.xgb')

        # Make sure the target directory exists
        os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

        # Save the native Booster model
        model.save_model(model_file_path)

        # Update the model-file reference in the payload
        self.payload['model_file'] = model_file_path
        joblib.dump(self.payload, metadata_path)

        print(f"[Checkpoint] Saved checkpoint to: {metadata_path}")

    def after_iteration(self, model, epoch, evals_log):
        # Read the current validation score ('test' is the validation set here)
        current_score = evals_log['test']['rmse'][-1]

        # Save the best model so far
        if current_score < self.best_score:
            self.best_score = current_score
            self._save_checkpoint(model, 'checkpoint_best')

        # Save a periodic checkpoint
        if (epoch + 1) % self.save_period == 0:
            self._save_checkpoint(model, f'checkpoint_epoch_{epoch + 1}')

        return False  # returning False tells XGBoost to continue training
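
# Example of the path rewriting above, assuming
#   base_path = 'saved_models/product_P001_all_xgboost_v2_model.pth':
# _save_checkpoint(model, 'checkpoint_best') then writes
#   'saved_models/product_P001_all_xgboost_v2_checkpoint_best.pth' (joblib metadata)
#   'saved_models/product_P001_all_xgboost_v2_checkpoint_best.xgb' (native Booster)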

def create_dataset(data, look_back=7):
    """
    Convert time-series data into a supervised-learning format.

    :param data: input DataFrame containing the features and target.
    :param look_back: size of the time window used for prediction.
    :return: X (features), y (targets)
    """
    X, y = [], []
    feature_columns = [col for col in data.columns if col != 'date']

    for i in range(len(data) - look_back):
        # Flatten all features inside the look_back window
        features = data[feature_columns].iloc[i:(i + look_back)].values.flatten()
        X.append(features)
        # The target is the first sales value after the window
        y.append(data['sales'].iloc[i + look_back])

    return np.array(X), np.array(y)
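
# Worked example: with look_back=7 and a DataFrame of 100 rows carrying the 7
# feature columns (plus 'date'), create_dataset returns X with shape (93, 49)
# and y with shape (93,).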

def train_product_model_with_xgboost(
        product_id,
        store_id=None,
        epochs=100,  # in XGBoost terms, this is the number of boosting rounds (n_estimators)
        look_back=7,
        socketio=None,
        task_id=None,
        version='v1',
        path_info=None,
        **kwargs):
    """
    Train a product sales forecasting model with XGBoost.
    """

    def emit_progress(message, progress=None):
        if socketio and task_id:
            payload = {'task_id': task_id, 'message': message}
            if progress is not None:
                payload['progress'] = progress
            socketio.emit('training_update', payload, namespace='/api/training', room=task_id)
        print(f"[{task_id}] {message}")

    try:
        model_path = None
        emit_progress("Starting XGBoost model training...", 0)

        # 1. Load the data
        # Use the correct loader and the data path from config
        full_df = load_multi_store_data(DEFAULT_DATA_PATH)

        # Filter by store_id and product_id
        if store_id:
            df = full_df[(full_df['product_id'] == product_id) & (full_df['store_id'] == store_id)].copy()
        else:
            # Without a store_id, aggregate the product's data across all stores
            df = full_df[full_df['product_id'] == product_id].groupby('date').agg({
                'sales': 'sum',
                'weekday': 'first',
                'month': 'first',
                'is_holiday': 'max',
                'is_weekend': 'max',
                'is_promotion': 'max',
                'temperature': 'mean'
            }).reset_index()

        if df.empty:
            raise ValueError(f"Loaded data is empty (product: {product_id}, store: {store_id}); cannot train.")

        # Make sure the data is sorted by date
        df = df.sort_values('date').reset_index(drop=True)

        emit_progress("Data loading complete.", 10)

        # 2. Build the dataset
        features_to_use = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        # Make sure every required feature exists
        for col in features_to_use:
            if col not in df.columns:
                # Fill missing features with 0
                df[col] = 0

        df_features = df[['date'] + features_to_use]

        X, y = create_dataset(df_features, look_back)
        if X.shape[0] == 0:
            raise ValueError("The dataset has 0 samples; check the data volume and the look_back parameter.")

        emit_progress(f"Dataset created, sample count: {X.shape[0]}", 20)

        # 3. Split into training and test sets
        X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

        # Scale the data
        scaler_X = MinMaxScaler()
        X_train_scaled = scaler_X.fit_transform(X_train)
        X_test_scaled = scaler_X.transform(X_test)

        scaler_y = MinMaxScaler()
        y_train_scaled = scaler_y.fit_transform(y_train.reshape(-1, 1))
        # y_test is left unscaled; metrics are computed against inverse-transformed predictions

        emit_progress("Data split and scaling complete.", 30)

        # 4. Switch to the native XGBoost API
        params = {
            'learning_rate': kwargs.get('learning_rate', 0.1),
            'max_depth': kwargs.get('max_depth', 5),
            'subsample': kwargs.get('subsample', 0.8),
            'colsample_bytree': kwargs.get('colsample_bytree', 0.8),
            'objective': 'reg:squarederror',
            'eval_metric': 'rmse',
            'random_state': 42
        }

        dtrain = xgb.DMatrix(X_train_scaled, label=y_train_scaled.ravel())
        dtest = xgb.DMatrix(X_test_scaled, label=scaler_y.transform(y_test.reshape(-1, 1)).ravel())

        emit_progress("Starting model training...", 40)

        # Define the evaluation sets
        evals = [(dtrain, 'train'), (dtest, 'test')]

        # Prepare the callbacks
        callbacks = []
        checkpoint_interval = kwargs.get('checkpoint_interval', 10)  # save every 10 rounds by default
        if path_info and path_info.get('model_path') and checkpoint_interval > 0:
            # Payload used for saving; the model file reference is injected
            # dynamically inside the callback
            checkpoint_payload = {
                'metrics': {},  # checkpoints do not store final metrics
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            checkpoint_callback = EpochCheckpointCallback(
                save_period=checkpoint_interval,
                payload=checkpoint_payload,
                base_path=path_info['model_path']
            )
            callbacks.append(checkpoint_callback)

        # Add the early-stopping callback (save_best removed)
        callbacks.append(EarlyStopping(rounds=10))

        # Container for the evaluation results
        evals_result = {}

        model = xgb.train(
            params=params,
            dtrain=dtrain,
            num_boost_round=epochs,
            evals=evals,
            callbacks=callbacks,
            evals_result=evals_result,
            verbose_eval=False
        )
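        # Note: if EarlyStopping fires, recent XGBoost versions expose the best
        # round via model.best_iteration, and evals_result only covers the
        # boosting rounds that actually ran.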
        emit_progress("Model training complete.", 80)

        # Plot and save the loss curve
        if path_info and path_info.get('model_path'):
            try:
                loss_curve_path = path_info['model_path'].replace('_model.pth', '_loss_curve.png')
                results = evals_result
                train_rmse = results['train']['rmse']
                test_rmse = results['test']['rmse']
                num_epochs = len(train_rmse)
                x_axis = range(num_epochs)

                fig, ax = plt.subplots(figsize=(10, 6))
                ax.plot(x_axis, train_rmse, label='Train')
                ax.plot(x_axis, test_rmse, label='Test')
                ax.legend()
                plt.ylabel('RMSE')
                plt.xlabel('Epoch')
                plt.title('XGBoost RMSE Loss Curve')
                plt.savefig(loss_curve_path)
                plt.close(fig)
                emit_progress(f"Loss curve saved to: {loss_curve_path}")
            except Exception as e:
                emit_progress(f"Warning: failed to plot the loss curve: {str(e)}")

        # 5. Evaluate the model
        dtest_pred = xgb.DMatrix(X_test_scaled)
        y_pred_scaled = model.predict(dtest_pred)
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        metrics = {
            'RMSE': np.sqrt(mean_squared_error(y_test, y_pred)),
            'MAE': mean_absolute_error(y_test, y_pred),
            'R2': r2_score(y_test, y_pred)
        }
        emit_progress(f"Model evaluation complete: {metrics}", 90)

        # 6. Save the model (native API)
        if path_info and path_info.get('model_path'):
            metadata_path = path_info['model_path']
            # Save the native Booster model with the .xgb extension
            model_file_path = metadata_path.replace('.pth', '.xgb')

            # Make sure the target directory exists
            os.makedirs(os.path.dirname(metadata_path), exist_ok=True)

            # Save the Booster with the native method
            model.save_model(model_file_path)
            emit_progress(f"Native XGBoost model saved to: {model_file_path}")

            # Save the metadata (including the model file path)
            metadata_payload = {
                'model_file': model_file_path,  # reference to the model file
                'metrics': metrics,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'look_back': look_back,
                    'features': features_to_use,
                    'product_id': product_id,
                    'store_id': store_id,
                    'version': version
                }
            }
            joblib.dump(metadata_payload, metadata_path)
            model_path = metadata_path  # make sure model_path is assigned
            emit_progress(f"Model metadata saved to: {metadata_path}", 100)
        else:
            emit_progress("Warning: no path_info provided; the model was not saved.", 100)

        return metrics, model_path

    except Exception as e:
        emit_progress(f"XGBoost training failed: {str(e)}", 100)
        import traceback
        traceback.print_exc()
        return {'error': str(e)}, None
```
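
Restoring a model saved in this two-file format is the mirror image of step 6: load the metadata with `joblib`, then load the referenced `.xgb` file into a fresh Booster. A minimal sketch (the path is illustrative):

```python
import joblib
import xgboost as xgb

# The .pth file holds the metadata payload: metrics, scalers, config,
# and the 'model_file' reference to the native .xgb model
payload = joblib.load("saved_models/product_P001_all_xgboost_v1_model.pth")

booster = xgb.Booster()
booster.load_model(payload['model_file'])  # the .xgb file saved next to the metadata

# Inference must replay the training-time scaling: scale the flattened
# look_back window with scaler_X, predict, then invert the target scaling
scaler_X, scaler_y = payload['scaler_X'], payload['scaler_y']
# X_window: shape (1, look_back * len(features)), built the same way as create_dataset()
# y_pred = scaler_y.inverse_transform(
#     booster.predict(xgb.DMatrix(scaler_X.transform(X_window))).reshape(-1, 1))
```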