2025-07-22 15:40:37 +08:00
|
|
|
|
# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器 (CNN-BiLSTM-Attention model trainer).
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import torch.optim as optim
|
|
|
|
|
import numpy as np
|
2025-07-23 16:55:27 +08:00
|
|
|
|
import time
|
|
|
|
|
import copy
|
2025-07-22 15:40:37 +08:00
|
|
|
|
|
|
|
|
|
from models.model_registry import register_trainer
|
|
|
|
|
from utils.model_manager import model_manager
|
|
|
|
|
from analysis.metrics import evaluate_model
|
|
|
|
|
from utils.data_utils import create_dataset
|
|
|
|
|
from sklearn.preprocessing import MinMaxScaler
|
2025-07-23 16:55:27 +08:00
|
|
|
|
from utils.visualization import plot_loss_curve # 导入绘图函数
|
2025-07-22 15:40:37 +08:00
|
|
|
|
|
|
|
|
|
# 导入新创建的模型
|
|
|
|
|
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
|
|
|
|
|
|
|
|
|
def _build_model_payload(state_dict, scaler_X, scaler_y, input_dim, output_dim,
                         sequence_length, features, **extra):
    """Assemble the dict handed to ``model_manager.save_model``.

    Produces the same layout the trainer persists for both the intermediate
    "best" checkpoint and the final model: ``model_state_dict``, the two
    fitted scalers, a ``config`` sub-dict describing how to rebuild the
    network, followed by any ``extra`` keys in the order they were passed.
    """
    payload = {
        'model_state_dict': state_dict,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': output_dim,
            'sequence_length': sequence_length,
            'features': features,
        },
    }
    payload.update(extra)
    return payload


def _resolve_scope_identifier(training_mode, product_id, store_id, aggregation_method, product_name):
    """Return the identifier used in the loss-curve filename for ``training_mode``."""
    if training_mode == 'product':
        return product_id
    if training_mode == 'store':
        return store_id
    if training_mode == 'global':
        return aggregation_method
    return product_name  # fallback for any other training mode


def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id,
                                    training_mode, aggregation_method, epochs,
                                    sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train a CNN-BiLSTM-Attention model with early stopping and best-model checkpointing.

    The series is scaled, split 80/20 by position, windowed with
    ``create_dataset`` and trained full-batch with Adam + MSE. Every time the
    validation loss improves, the weights are checkpointed via
    ``model_manager.save_model`` under version ``"<version>_best"``. After
    training, the best weights are evaluated and saved as the final model
    under the version reserved before training started.

    Args:
        product_id: Forwarded to ``model_manager``; also the loss-curve
            identifier when ``training_mode == 'product'``.
        model_identifier: Label shown in the startup log; used as the product
            name when ``product_df`` has no ``product_name`` column.
        product_df: Frame containing every column in ``features`` (see body).
            The split is positional, so rows are presumably in chronological
            order -- TODO confirm with callers.
        store_id: Forwarded to ``model_manager``; loss-curve identifier when
            ``training_mode == 'store'``.
        training_mode: ``'product'``, ``'store'`` or ``'global'`` select the
            loss-curve identifier; any other value falls back to the product name.
        aggregation_method: Forwarded to ``model_manager``; loss-curve
            identifier when ``training_mode == 'global'``.
        epochs: Maximum number of full-batch epochs (early stopping may end sooner).
        sequence_length: Input window length for ``create_dataset`` and the model.
        forecast_horizon: Number of predicted steps (the model's ``output_dim``).
        model_dir: Directory passed to ``plot_loss_curve``.
        **kwargs: ``learning_rate`` (Adam, default 0.001) and ``patience``
            (early-stopping patience in epochs, default 15).

    Returns:
        tuple: ``(model, metrics, artifacts)``. ``model`` holds the best
        weights when a best checkpoint exists. ``metrics`` is the dict from
        ``evaluate_model`` plus ``training_time``, ``best_val_loss`` and
        ``stopped_epoch``. ``artifacts`` has keys ``versioned_model``,
        ``loss_curve_plot``, ``best_model`` (``None`` if no checkpoint was
        saved) and ``version``.

    Raises:
        ValueError: If the train or test split is too short to build even one
            input/target window.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. Data preparation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    product_name = product_df['product_name'].iloc[0] if 'product_name' in product_df.columns else model_identifier

    X = product_df[features].values
    y = product_df[['sales']].values

    # NOTE(review): both scalers are fit on the full series (train + test) before
    # the split, so the test portion influences the scaling range (leakage).
    # Kept as-is because the fitted scalers are persisted with the model and
    # downstream consumers may rely on that range; fitting on the training slice
    # only would remove the leakage.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Positional 80/20 split (no shuffling).
    train_size = int(len(X_scaled) * 0.8)
    X_train_raw, X_test_raw = X_scaled[:train_size], X_scaled[train_size:]
    y_train_raw, y_test_raw = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # Fail fast: with no windows, training/validation would otherwise crash late
    # (shape indexing, NaN losses, or sklearn rejecting an empty array).
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"数据不足: 训练集 {len(X_train_raw)} 行, 测试集 {len(X_test_raw)} 行, "
            f"无法按 sequence_length={sequence_length}, forecast_horizon={forecast_horizon} 构造样本窗口"
        )

    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # Targets as compared against the model output; squeezed once, not per epoch.
    train_target = trainY.squeeze(-1)
    test_target = testY.squeeze(-1)

    # --- 2. Model and optimizer ---
    input_dim = trainX.shape[2]
    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length,
    )
    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop with early stopping ---
    print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")

    # Reserve the version number before training so the best checkpoint, the
    # loss curve and the final model all share it.
    current_version = model_manager.peek_next_version(
        model_type='cnn_bilstm_attention',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')
    best_model_state = None
    best_model_path = None  # path of the most recent "_best" checkpoint, if any
    patience = kwargs.get('patience', 15)
    patience_counter = 0
    stopped_epoch = 0  # stays 0 when epochs <= 0 and the loop never runs

    # NOTE(review): the held-out test split doubles as the validation set for
    # early stopping, and the final metrics are computed on that same split, so
    # the reported metrics are somewhat optimistic.
    for epoch in range(epochs):
        stopped_epoch = epoch + 1

        # Full-batch training step.
        model.train()
        optimizer.zero_grad()
        outputs = model(trainX)
        train_loss = criterion(outputs, train_target)
        train_loss.backward()
        optimizer.step()

        # Validation pass.
        model.eval()
        with torch.no_grad():
            val_outputs = model(testX)
            val_loss = criterion(val_outputs, test_target)

        train_loss_value = train_loss.item()
        val_loss_value = val_loss.item()
        loss_history['train'].append(train_loss_value)
        loss_history['val'].append(val_loss_value)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss_value:.4f}, Val Loss: {val_loss_value:.4f}')

        # Early stopping. A NaN loss compares False here and counts as no improvement.
        if val_loss_value < best_val_loss:
            best_val_loss = val_loss_value
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")

            # Persist the new best checkpoint immediately so it survives a crash
            # later in training.
            best_model_data = _build_model_payload(
                best_model_state, scaler_X, scaler_y, input_dim, forecast_horizon,
                sequence_length, features,
                epoch=epoch + 1,
            )
            best_model_path, _ = model_manager.save_model(
                model_data=best_model_data,
                product_id=product_id,
                model_type='cnn_bilstm_attention',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best",
            )
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
                break

    training_time = time.time() - start_time
    print(f"模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. Evaluate with the best weights ---
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("最佳模型已加载用于最终评估。")
        final_state = best_model_state
    else:
        # No improvement was ever recorded (epochs <= 0 or every val loss was
        # NaN): keep the last trained weights so the saved model is never None.
        final_state = copy.deepcopy(model.state_dict())

    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # scaler_y was fit on the single 'sales' column, so un-scale through an
    # (n, 1) view regardless of the output's shape; metrics use flat arrays.
    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy().reshape(-1, 1))
    test_true_unscaled = scaler_y.inverse_transform(test_target.numpy().reshape(-1, 1))

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    metrics['training_time'] = training_time
    metrics['best_val_loss'] = best_val_loss
    metrics['stopped_epoch'] = stopped_epoch

    print("\n最佳模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Save artifacts ---
    scope = training_mode
    identifier = _resolve_scope_identifier(training_mode, product_id, store_id, aggregation_method, product_name)

    # Loss curve tagged with the reserved version.
    loss_curve_path = plot_loss_curve(
        train_losses=loss_history['train'],
        val_losses=loss_history['val'],
        model_type='cnn_bilstm_attention',
        scope=scope,
        identifier=identifier,
        version=current_version,
        model_dir=model_dir,
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    model_data = _build_model_payload(
        final_state, scaler_X, scaler_y, input_dim, forecast_horizon,
        sequence_length, features,
        metrics=metrics,
        loss_history=loss_history,
        loss_curve_path=loss_curve_path,
    )

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version,
    )
    print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version,
    }
    return model, metrics, artifacts
|
2025-07-22 15:40:37 +08:00
|
|
|
|
|
|
|
|
|
# --- Key step: register this trainer with the system ---
# Registers train_with_cnn_bilstm_attention under the 'cnn_bilstm_attention'
# key via models.model_registry.register_trainer (runs at import time).
register_trainer(
    'cnn_bilstm_attention',
    train_with_cnn_bilstm_attention,
)
|