ShopTRAINING/server/trainers/cnn_bilstm_attention_trainer.py
xz2000 348741d49c 解决模型保存过程中的原子性问题,以确保系统在遇到任何错误时都能保持数据的一致性,避免产生“孤儿文件”或不完整的数据库记录。
主要变更和成果:

实现了原子化的模型保存逻辑:

在核心的 PharmacyPredictor 类 (ShopTRAINING/server/core/predictor.py) 中,我们引入了一个新的私有方法 _save_model_transactional。
此方法将所有与模型保存相关的操作(包括所有文件工件的磁盘写入和元数据到数据库的写入)封装成一个单一的事务性操作。
通过使用 try...except 块,我们确保了只有在数据库成功记录元数据后,整个保存操作才被视为成功。
引入了自动回滚机制:

如果在 _save_model_transactional 方法执行期间发生任何错误(例如,数据库连接失败、磁盘空间不足等),except 块会触发回滚。
回滚操作由 _rollback_files 方法执行,它会负责删除在该次失败的训练过程中已经生成的所有文件(如模型文件 .pth、损失曲线图 .png 和损失数据 .json)。
这保证了文件系统不会留下任何与数据库记录不匹配的“孤儿”文件,从而维护了系统的整洁和数据一致性。
重构了核心训练流程:

修改了 PharmacyPredictor 类中的 train_model 方法,使其在训练器成功返回结果后,不再直接返回,而是调用新的 _save_model_transactional 方法。
这使得保存逻辑从 api.py 的路由处理函数中解耦,并集中到负责业务流程编排的 predictor 核心类中,使代码结构更清晰,职责更分明。
相应地,api.py 中的后台训练任务逻辑也进行了简化,它现在只需信任 predictor 返回的结果,无需再关心具体的保存细节。
2025-07-29 16:30:39 +08:00

279 lines
9.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# -*- coding: utf-8 -*-
"""
CNN-BiLSTM-Attention 模型训练器
"""
import pandas as pd
import torch
import torch.optim as optim
import numpy as np
import os
import time
import copy
from models.model_registry import register_trainer
from utils.model_manager import model_manager
from analysis.metrics import evaluate_model
from utils.data_utils import prepare_data, prepare_sequences
from sklearn.preprocessing import MinMaxScaler
from utils.visualization import plot_loss_curve # 导入绘图函数
import json # 导入json库
# 导入新创建的模型
from models.cnn_bilstm_attention import CnnBiLstmAttention
def train_with_cnn_bilstm_attention(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int,
    sequence_length: int,
    forecast_horizon: int,
    model_dir: str,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    clip_norm: float = 1.0,  # max gradient norm used for gradient clipping
    **kwargs
):
    """Train a CNN-BiLSTM-Attention model with early stopping and save its artifacts.

    The function prepares the data via ``prepare_data`` / ``prepare_sequences``,
    trains with Adam + MSE loss and gradient clipping, checkpoints every new
    best validation loss as ``"<version>_best"``, evaluates the best weights on
    the test loader, and finally saves the versioned model, the loss-curve plot
    and a JSON dump of the loss history.

    Args:
        model_identifier: Label used in log output and as the fallback product
            name when ``training_df`` has no ``product_name`` column.
        training_df: Training data; the target column is ``net_sales_quantity``.
        feature_list: Pre-selected feature names handed to ``prepare_data``.
        training_mode: ``'product'``, ``'store'`` or ``'global'``; selects which
            id is used as the artifact identifier (any other value falls back
            to the product name).
        epochs: Maximum number of training epochs.
        sequence_length: Input window length passed to the data pipeline and
            the model.
        forecast_horizon: Number of predicted steps (the model's output_dim).
        model_dir: Directory that receives the loss-curve plot and JSON file.
        product_id: Product id forwarded to the model manager.
        store_id: Store id forwarded to the model manager.
        aggregation_method: Aggregation method forwarded to the model manager.
        version: Accepted for interface compatibility but not used; the version
            is obtained from ``model_manager.peek_next_version``.
        clip_norm: ``max_norm`` passed to ``clip_grad_norm_``.
        **kwargs: Optional ``batch_size`` (default 32), ``learning_rate``
            (default 0.001) and ``patience`` (default 15).

    Returns:
        A ``(metrics, artifacts)`` tuple. ``metrics`` is the result of
        ``evaluate_model`` extended with ``training_time``, ``best_val_loss``
        and ``stopped_epoch``. ``artifacts`` maps ``versioned_model``,
        ``loss_curve_plot``, ``loss_curve_data``, ``best_model``, ``version``
        and ``loss_history`` to the corresponding paths/values.

    Raises:
        ValueError: If the training or validation loader contains no batches.
    """
    model_type = 'cnn_bilstm_attention'
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. Data preparation ---
    if 'product_name' in training_df.columns:
        product_name = training_df['product_name'].iloc[0]
    else:
        product_name = model_identifier
    print(f"[Hybrid] 开始数据预处理,使用 {len(feature_list)} 个预选特征...")

    # Standardised data preparation helper.
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y, used_features = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )

    # Standardised sequence/loader helper; batch_size may be overridden by the caller.
    batch_size = kwargs.get('batch_size', 32)
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)

    # Fail early with a clear message instead of a ZeroDivisionError when
    # averaging the per-batch losses below.
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError(
            "Not enough data to train: the training or validation loader is empty "
            f"(train batches={len(train_loader)}, val batches={len(test_loader)})."
        )

    # --- 2. Model and optimizer ---
    input_dim = trainX.shape[2]  # index 2 of trainX is used as the feature count
    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )
    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # Model configuration stored alongside every checkpoint.
    model_config = {
        'model_type': model_type,
        'input_dim': input_dim,
        'output_dim': forecast_horizon,
        'sequence_length': sequence_length,
        'features': used_features
    }

    # --- 3. Training loop with early stopping ---
    print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")

    # Version lock: fix the version for this run before training starts so the
    # "_best" checkpoint, the plot and the final model all share it.
    current_version = model_manager.peek_next_version(
        model_type=model_type,
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')
    best_model_state = None
    best_model_path = None  # path of the most recent "_best" checkpoint
    patience = kwargs.get('patience', 15)
    patience_counter = 0
    # Tracked explicitly so it is defined even when the loop body never runs.
    stopped_epoch = 0

    for epoch in range(epochs):
        stopped_epoch = epoch + 1

        model.train()
        running_train_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            batch_loss = criterion(outputs, y_batch.squeeze(-1))
            batch_loss.backward()
            # Gradient clipping.
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()
            running_train_loss += batch_loss.item()
        train_loss = running_train_loss / len(train_loader)

        # Validation pass.
        model.eval()
        running_val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                val_outputs = model(X_batch)
                running_val_loss += criterion(val_outputs, y_batch.squeeze(-1)).item()
        val_loss = running_val_loss / len(test_loader)

        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early-stopping bookkeeping.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() returns references to the live parameters.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")

            # Persist the new best checkpoint immediately.
            best_model_path, _ = model_manager.save_model(
                model_data={
                    'model_state_dict': best_model_state,
                    'scaler_X': scaler_X,
                    'scaler_y': scaler_y,
                    'config': dict(model_config),
                    'epoch': epoch + 1
                },
                product_id=product_id,
                model_type=model_type,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
                break

    training_time = time.time() - start_time
    print(f"模型训练完成,耗时: {training_time:.2f}")

    # --- 4. Evaluate with the best weights ---
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("最佳模型已加载用于最终评估。")
    else:
        # No epoch improved on the initial loss (epochs <= 0 or non-finite
        # validation losses). Fall back to the current weights so the final
        # artifact never contains a None state dict.
        best_model_state = copy.deepcopy(model.state_dict())
        print("⚠️ 未找到更优的验证损失,使用当前模型权重进行最终评估。")

    model.eval()
    with torch.no_grad():
        # Materialise the loader once so inputs and targets stay aligned.
        test_batches = list(test_loader)
        testX_tensor = torch.cat([xb for xb, _ in test_batches], dim=0)
        testY_tensor = torch.cat([yb for _, yb in test_batches], dim=0)
        test_pred_scaled = model(testX_tensor)

    # Map predictions and targets back to the original target scale.
    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY_tensor.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    metrics['training_time'] = training_time
    metrics['best_val_loss'] = best_val_loss
    metrics['stopped_epoch'] = stopped_epoch

    print("\n最佳模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Save artifacts ---
    # scope/identifier determine the standardised artifact file names.
    scope = training_mode
    identifier = {
        'product': product_id,
        'store': store_id,
        'global': aggregation_method
    }.get(scope, product_name)  # product name is the fallback for unknown modes

    # Versioned loss-curve plot.
    loss_curve_path = plot_loss_curve(
        train_losses=loss_history['train'],
        val_losses=loss_history['val'],
        model_type=model_type,
        scope=scope,
        identifier=identifier,
        version=current_version,  # the locked version
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    # Final model payload (best weights plus metrics and loss history).
    model_data = {
        'model_state_dict': best_model_state,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': dict(model_config),
        'metrics': metrics,
        'loss_history': loss_history,
        'loss_curve_path': loss_curve_path
    }

    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type=model_type,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version  # the locked version
    )
    print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")

    # --- Loss history as a JSON file ---
    loss_data_filename = f"{identifier}_{current_version}_loss_curve_data.json"
    loss_data_path = os.path.join(model_dir, loss_data_filename)
    loss_data_to_save = {
        'epochs': list(range(1, len(loss_history['train']) + 1)),
        'train_loss': loss_history['train'],
        'test_loss': loss_history['val']
    }
    # Make sure the target directory exists before writing.
    os.makedirs(model_dir, exist_ok=True)
    with open(loss_data_path, 'w', encoding='utf-8') as f:
        json.dump(loss_data_to_save, f)
    print(f"💾 损失历史数据已保存: {loss_data_path}")

    # Artifacts handed back to the caller.
    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "loss_curve_data": loss_data_path,
        "best_model": best_model_path,
        "version": final_version,
        'loss_history': loss_history
    }
    return metrics, artifacts
# --- Key step: register this trainer with the system under the
# 'cnn_bilstm_attention' key (runs at import time) ---
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)