# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model trainer.
"""

import torch
import torch.optim as optim
import numpy as np

from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import create_dataset
from sklearn.preprocessing import MinMaxScaler

# Import the newly created model
from models.cnn_bilstm_attention import CnnBiLstmAttention


def train_with_cnn_bilstm_attention(product_id, model_identifier, product_df, store_id,
                                    training_mode, aggregation_method, epochs,
                                    sequence_length, forecast_horizon, model_dir, **kwargs):
    """Train a CNN-BiLSTM-Attention model and persist it via ``model_manager``.

    The signature follows the system's standard trainer interface.

    Args:
        product_id: Product identifier, forwarded to ``model_manager.save_model``.
        model_identifier: Identifier used only in the start-up log message.
        product_df: DataFrame holding the feature columns ``sales``, ``weekday``,
            ``month``, ``is_holiday``, ``is_weekend``, ``is_promotion``,
            ``temperature`` and a ``product_name`` column.
        store_id: Forwarded to ``model_manager.save_model``.
        training_mode: Forwarded to ``model_manager.save_model``.
        aggregation_method: Forwarded to ``model_manager.save_model``.
        epochs: Number of full-batch optimisation steps.
        sequence_length: Input window length passed to ``create_dataset`` and
            to the model constructor.
        forecast_horizon: Number of steps predicted; used as the model's
            ``output_dim``.
        model_dir: Accepted for interface compatibility; not used here.
        **kwargs: Optional ``learning_rate`` (default ``0.001``).

    Returns:
        Tuple ``(model, metrics, final_version, final_model_path)``.

    Raises:
        ValueError: If ``product_df`` is too short to produce at least one
            training window and one test window.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")

    # --- 1. Data preparation ---
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend',
                'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values

    # Chronological 80/20 split, done BEFORE scaling.
    train_size = int(len(X) * 0.8)
    if train_size == 0:
        raise ValueError(
            f"Not enough rows to train: got {len(X)}, need at least 2."
        )

    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    # Fit the scalers on the training portion only so that the min/max of the
    # held-out data cannot leak into training or into the reported metrics.
    X_train_raw = scaler_X.fit_transform(X[:train_size])
    y_train_raw = scaler_y.fit_transform(y[:train_size])
    X_test_raw = scaler_X.transform(X[train_size:])
    y_test_raw = scaler_y.transform(y[train_size:])

    trainX, trainY = create_dataset(X_train_raw, y_train_raw, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test_raw, y_test_raw, sequence_length, forecast_horizon)

    # Fail early with a readable message instead of an IndexError on
    # ``trainX.shape[2]`` when a split is too short to yield any window.
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            "Not enough data to build training/test windows: "
            f"rows={len(X)}, sequence_length={sequence_length}, "
            f"forecast_horizon={forecast_horizon}."
        )

    # Convert to PyTorch tensors
    trainX = torch.from_numpy(trainX).float()
    trainY = torch.from_numpy(trainY).float()
    testX = torch.from_numpy(testX).float()
    testY = torch.from_numpy(testY).float()

    # Flatten targets to (num_samples, -1) so they line up with the model
    # output. Unlike ``squeeze(-1)`` this never collapses an (N, 1) target to
    # (N,), which would make MSELoss broadcast against (N, 1) predictions.
    # NOTE(review): assumes create_dataset yields targets with the sample axis
    # first and forecast_horizon values per sample -- confirm.
    train_targets = trainY.reshape(trainY.shape[0], -1)
    test_targets = testY.reshape(testY.shape[0], -1)

    # --- 2. Instantiate the model and optimizer ---
    input_dim = trainX.shape[2]

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )
    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop (full batch: one optimizer step per epoch) ---
    print("开始训练 CNN-BiLSTM-Attention 模型...")
    model.train()
    for epoch in range(epochs):
        optimizer.zero_grad()
        outputs = model(trainX)
        loss = criterion(outputs, train_targets)
        loss.backward()
        optimizer.step()

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Loss: {loss.item():.4f}')

    # --- 4. Model evaluation ---
    model.eval()
    with torch.no_grad():
        test_pred_scaled = model(testX)

    # scaler_y was fitted on a single column, so feed it column vectors; the
    # row-major flattening order is the same as flattening the (N, horizon)
    # arrays directly.
    test_pred_unscaled = scaler_y.inverse_transform(
        test_pred_scaled.numpy().reshape(-1, 1)
    )
    test_true_unscaled = scaler_y.inverse_transform(
        test_targets.numpy().reshape(-1, 1)
    )

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    print(f"模型评估完成: RMSE={metrics['rmse']:.4f}")

    # --- 5. Model saving ---
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'model_type': 'cnn_bilstm_attention',
            'input_dim': input_dim,
            'output_dim': forecast_horizon,
            'sequence_length': sequence_length,
            'features': features
        },
        'metrics': metrics
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_df['product_name'].iloc[0]
    )

    print(f"✅ CNN-BiLSTM-Attention 模型已保存,版本: {final_version}")
    return model, metrics, final_version, final_model_path


# --- Key step: register the trainer with the system ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)