# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention model trainer.

Trains a ``CnnBiLstmAttention`` network with gradient clipping and early
stopping, evaluates the best checkpoint on the held-out split, and persists
the versioned model and its loss curve through ``model_manager``.
"""
import pandas as pd
import torch
import torch.optim as optim
import numpy as np
import time
import copy
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import prepare_data, prepare_sequences
from sklearn.preprocessing import MinMaxScaler
from utils.visualization import plot_loss_curve  # loss-curve plotting helper

# The network architecture trained by this module.
from models.cnn_bilstm_attention import CnnBiLstmAttention

# Model-type key used for registration, versioning and saved configs.
_MODEL_TYPE = 'cnn_bilstm_attention'


def _build_config(input_dim, forecast_horizon, sequence_length, used_features):
    """Return a fresh architecture/feature config dict stored with every checkpoint."""
    return {
        'model_type': _MODEL_TYPE,
        'input_dim': input_dim,
        'output_dim': forecast_horizon,
        'sequence_length': sequence_length,
        'features': used_features,
    }


def _train_one_epoch(model, loader, optimizer, criterion, clip_norm):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in loader:
        optimizer.zero_grad()
        outputs = model(X_batch)
        # NOTE(review): the last target axis is squeezed to match the model output;
        # presumably y is (batch, horizon, 1) — confirm against prepare_data.
        loss = criterion(outputs, y_batch.squeeze(-1))
        loss.backward()
        # Gradient clipping bounds the update size for the recurrent layers.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _validate(model, loader, criterion):
    """Return the mean per-batch loss of ``model`` on ``loader`` without tracking gradients."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            total_loss += criterion(model(X_batch), y_batch.squeeze(-1)).item()
    return total_loss / len(loader)


def _predict(model, loader):
    """Run batched inference; return (predictions, targets) as numpy arrays in scaled space."""
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        # Batched forward passes avoid materialising the whole test set as one tensor.
        for X_batch, y_batch in loader:
            predictions.append(model(X_batch))
            targets.append(y_batch.squeeze(-1))
    return torch.cat(predictions, dim=0).numpy(), torch.cat(targets, dim=0).numpy()


def _resolve_identifier(training_mode, product_id, store_id, aggregation_method, product_name):
    """Pick the identifier used in artifact file names for the given training scope."""
    scope_to_identifier = {
        'product': product_id,
        'store': store_id,
        'global': aggregation_method,
    }
    # Unknown scopes fall back to the product name.
    return scope_to_identifier.get(training_mode, product_name)


def train_with_cnn_bilstm_attention(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int,
    sequence_length: int,
    forecast_horizon: int,
    model_dir: str,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    clip_norm: float = 1.0,  # max gradient norm used for clipping
    **kwargs
):
    """
    Train a CNN-BiLSTM-Attention model (new data-pipeline version).

    Args:
        model_identifier: Name used in log output and as the fallback product
            name when ``training_df`` has no ``product_name`` column.
        training_df: Source data passed to ``prepare_data``.
        feature_list: Pre-selected feature columns.
        training_mode: Scope of the model (``'product'``, ``'store'``,
            ``'global'`` or other); drives versioning and artifact naming.
        epochs: Maximum number of training epochs.
        sequence_length: Input window length.
        forecast_horizon: Number of steps predicted (model output size).
        model_dir: Directory handed to ``plot_loss_curve``.
        product_id, store_id, aggregation_method: Forwarded to
            ``model_manager`` for versioning/saving.
        version: Accepted for interface compatibility but not used; the
            version is locked via ``model_manager.peek_next_version``.
        clip_norm: Max gradient norm for ``clip_grad_norm_``.
        **kwargs: Optional ``batch_size`` (default 32), ``learning_rate``
            (default 0.001) and ``patience`` (default 15).

    Returns:
        A ``(metrics, artifacts)`` tuple. ``metrics`` is the dict from
        ``evaluate_model`` plus ``training_time``, ``best_val_loss`` and
        ``stopped_epoch``. ``artifacts`` holds ``versioned_model``,
        ``loss_curve_plot``, ``best_model`` (``None`` if no epoch improved)
        and ``version``.

    Raises:
        ValueError: If the prepared train or test loader yields no batches.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. Data preparation ---
    if 'product_name' in training_df.columns:
        product_name = training_df['product_name'].iloc[0]
    else:
        product_name = model_identifier
    print(f"[Hybrid] 开始数据预处理,使用 {len(feature_list)} 个预选特征...")

    # Standardised data preparation (scaling + train/test split).
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y, used_features = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )

    # Standardised batching; batch_size may be supplied by the caller.
    batch_size = kwargs.get('batch_size', 32)
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)
    if len(train_loader) == 0 or len(test_loader) == 0:
        # Without this check the mean-loss computation divides by zero.
        raise ValueError(
            "prepare_data produced an empty train or test split; check "
            "sequence_length/forecast_horizon against the amount of data."
        )

    # --- 2. Model and optimizer ---
    input_dim = trainX.shape[2]  # number of features
    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )
    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # --- 3. Training loop with early stopping ---
    print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")

    # Lock the version number before training so the "_best" checkpoints and
    # the final model share it.
    current_version = model_manager.peek_next_version(
        model_type=_MODEL_TYPE,
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')
    best_model_state = None
    best_model_path = None  # path of the most recently saved best checkpoint
    patience = kwargs.get('patience', 15)
    patience_counter = 0
    stopped_epoch = 0  # stays 0 if the loop never runs (epochs <= 0)

    for epoch in range(epochs):
        stopped_epoch = epoch + 1
        train_loss = _train_one_epoch(model, train_loader, optimizer, criterion, clip_norm)
        val_loss = _validate(model, test_loader, criterion)

        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")

            # Persist the best checkpoint immediately so it survives an interrupted run.
            best_model_data = {
                'model_state_dict': best_model_state,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': _build_config(input_dim, forecast_horizon, sequence_length, used_features),
                'epoch': epoch + 1
            }
            best_model_path, _ = model_manager.save_model(
                model_data=best_model_data,
                product_id=product_id,
                model_type=_MODEL_TYPE,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
                break

    training_time = time.time() - start_time
    print(f"模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. Evaluate with the best model ---
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("最佳模型已加载用于最终评估。")
    else:
        # No epoch improved on +inf (no epochs ran, or every val loss was NaN).
        # Use the current weights so the saved artifact is still loadable
        # instead of storing a None state dict.
        best_model_state = copy.deepcopy(model.state_dict())
        print("⚠️ 未找到更优的验证损失, 使用当前模型权重进行评估和保存。")

    test_pred_scaled, test_true_scaled = _predict(model, test_loader)
    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled)
    test_true_unscaled = scaler_y.inverse_transform(test_true_scaled)

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    metrics['training_time'] = training_time
    metrics['best_val_loss'] = best_val_loss
    metrics['stopped_epoch'] = stopped_epoch

    print("\n最佳模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Save artifacts ---
    # Scope and identifier drive the standardised artifact file names.
    scope = training_mode
    identifier = _resolve_identifier(training_mode, product_id, store_id, aggregation_method, product_name)

    # Loss-curve plot tagged with the locked version.
    loss_curve_path = plot_loss_curve(
        train_losses=loss_history['train'],
        val_losses=loss_history['val'],
        model_type=_MODEL_TYPE,
        scope=scope,
        identifier=identifier,
        version=current_version,
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    # Final model payload: best weights plus scalers, config, metrics and history.
    model_data = {
        'model_state_dict': best_model_state,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': _build_config(input_dim, forecast_horizon, sequence_length, used_features),
        'metrics': metrics,
        'loss_history': loss_history,
        'loss_curve_path': loss_curve_path
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type=_MODEL_TYPE,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version
    )
    print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version
    }
    return metrics, artifacts


# --- Key step: register this trainer with the system ---
register_trainer(_MODEL_TYPE, train_with_cnn_bilstm_attention)