ShopTRAINING/server/trainers/cnn_bilstm_attention_trainer.py
2025-07-22 15:41:05 +08:00

118 lines
4.0 KiB
Python

# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器
"""
import torch
import torch.optim as optim
import numpy as np
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler
# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention
def _build_datasets(product_df, features, sequence_length, forecast_horizon):
    """Scale ``product_df`` and turn it into windowed train/test tensors.

    Rows are split positionally: the first 80% for training and the last 20%
    for testing, without shuffling. Both scalers are fitted on the training
    rows only and then applied to the test rows. Test-set statistics
    therefore do not leak into the training inputs or the reported metrics.

    Args:
        product_df: DataFrame containing every column in ``features``;
            ``'sales'`` is also used as the target. Presumably in
            chronological order, since the split is positional.
        features: Column names used as model inputs.
        sequence_length: Window length passed to ``create_dataset``.
        forecast_horizon: Number of future steps passed to ``create_dataset``.

    Returns:
        ``(trainX, trainY, testX, testY, scaler_X, scaler_y)``. The first four
        are float32 tensors built from ``create_dataset``'s output. The last
        two are the fitted ``MinMaxScaler`` instances.

    Raises:
        ValueError: if there are too few rows to build at least one training
            window and one test window.
    """
    X = product_df[features].values
    y = product_df[['sales']].values

    train_size = int(len(X) * 0.8)
    # int(0.8 * n) < n for every n >= 1, so a non-zero train split always
    # leaves at least one test row.
    if train_size == 0:
        raise ValueError(
            f"CNN-BiLSTM-Attention: not enough data to split into train/test "
            f"sets ({len(X)} rows)."
        )

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_train_raw = scaler_X.fit_transform(X[:train_size])
    X_test_raw = scaler_X.transform(X[train_size:])
    y_train_raw = scaler_y.fit_transform(y[:train_size])
    y_test_raw = scaler_y.transform(y[train_size:])

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"CNN-BiLSTM-Attention: {len(X)} rows are too few for "
            f"sequence_length={sequence_length} and "
            f"forecast_horizon={forecast_horizon} "
            f"(train windows: {len(trainX)}, test windows: {len(testX)})."
        )

    return (
        torch.from_numpy(trainX).float(),
        torch.from_numpy(trainY).float(),
        torch.from_numpy(testX).float(),
        torch.from_numpy(testY).float(),
        scaler_X,
        scaler_y,
    )


def _evaluate(model, testX, test_target, scaler_y):
    """Predict on ``testX`` and score against ``test_target`` in sales units.

    Args:
        model: Trained model; it is switched to eval mode here.
        testX: Input windows for the test split.
        test_target: Scaled targets, shaped ``(samples, -1)``.
        scaler_y: Scaler fitted on the single ``sales`` column.

    Returns:
        Whatever ``evaluate_model`` returns for the flattened true and
        predicted values. The caller reads its ``'rmse'`` entry.
    """
    model.eval()
    with torch.no_grad():
        pred_scaled = model(testX)
    # scaler_y was fitted on one column, and inverse_transform documents
    # (n_samples, n_features) input. So every value is inverted as its own row
    # of a (-1, 1) matrix. ravel() then gives the same row-major order as
    # flattening the (samples, horizon) array.
    y_pred = scaler_y.inverse_transform(pred_scaled.numpy().reshape(-1, 1)).ravel()
    y_true = scaler_y.inverse_transform(test_target.numpy().reshape(-1, 1)).ravel()
    return evaluate_model(y_true, y_pred)


def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id, training_mode, aggregation_method, epochs, sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train a CNN-BiLSTM-Attention model for one product and save it.

    The signature follows the standard shared by the system's trainers.
    ``model_identifier`` is only used in the start-up log line. ``model_dir``
    is accepted but not used; persistence goes through
    ``model_manager.save_model``.

    Args:
        product_id: Forwarded to ``model_manager.save_model``.
        model_identifier: Label printed when training starts.
        product_df: DataFrame with the feature columns listed below and a
            ``'product_name'`` column (the first row's value is saved).
            Presumably in chronological order, since the train/test split is
            positional.
        store_id: Forwarded to ``model_manager.save_model``.
        training_mode: Forwarded to ``model_manager.save_model``.
        aggregation_method: Forwarded to ``model_manager.save_model``.
        epochs: Number of optimisation steps. Each step uses the whole
            training set as a single batch.
        sequence_length: Input window length.
        forecast_horizon: Number of future steps predicted (model output size).
        model_dir: Unused; kept for signature compatibility.
        **kwargs: ``learning_rate`` for Adam (default 0.001).

    Returns:
        ``(model, metrics, final_version, final_model_path)``.

    Raises:
        ValueError: if ``product_df`` is too short to build train and test
            windows.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data preparation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    trainX, trainY, testX, testY, scaler_X, scaler_y = _build_datasets(
        product_df, features, sequence_length, forecast_horizon
    )
    # Shape targets as (samples, -1) to match the model's
    # output_dim=forecast_horizon output. squeeze(-1) is not used because
    # 2-D (samples, 1) targets would become (samples,). MSELoss would then
    # broadcast that against (samples, 1) predictions, emitting only a warning.
    train_target = trainY.reshape(trainY.shape[0], -1)
    test_target = testY.reshape(testY.shape[0], -1)

    # --- 2. Model and optimiser ---
    input_dim = trainX.shape[2]
    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )
    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop (full batch: one optimiser step per epoch) ---
    print("开始训练 CNN-BiLSTM-Attention 模型...")
    model.train()  # nothing in the loop changes the mode, so set it once
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(trainX), train_target)
        loss.backward()
        optimizer.step()
        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Evaluation on the held-out tail, in original sales units ---
    metrics = _evaluate(model, testX, test_target, scaler_y)
    print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")

    # --- 5. Save the model ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics
    }
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_df['product_name'].iloc[0]
    )
    print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
    return model, metrics, final_version, final_model_path
# --- Key step: register this trainer with the system ---
# Runs once at import time and hands train_with_cnn_bilstm_attention to the
# model registry under the name 'cnn_bilstm_attention'.
register_trainer(
    'cnn_bilstm_attention',
    train_with_cnn_bilstm_attention,
)