主要变更和成果:

1. 实现了原子化的模型保存逻辑:在核心的 PharmacyPredictor 类(ShopTRAINING/server/core/predictor.py)中,我们引入了一个新的私有方法 _save_model_transactional。此方法将所有与模型保存相关的操作(包括所有文件工件的磁盘写入和元数据到数据库的写入)封装成一个单一的事务性操作。通过使用 try...except 块,我们确保了只有在数据库成功记录元数据后,整个保存操作才被视为成功。

2. 引入了自动回滚机制:如果在 _save_model_transactional 方法执行期间发生任何错误(例如数据库连接失败、磁盘空间不足等),except 块会触发回滚。回滚操作由 _rollback_files 方法执行,它负责删除在该次失败的训练过程中已经生成的所有文件(如模型文件 .pth、损失曲线图 .png 和损失数据 .json)。这保证了文件系统不会留下任何与数据库记录不匹配的“孤儿”文件,从而维护了系统的整洁和数据一致性。

3. 重构了核心训练流程:修改了 PharmacyPredictor 类中的 train_model 方法,使其在训练器成功返回结果后不再直接返回,而是调用新的 _save_model_transactional 方法。这使得保存逻辑从 api.py 的路由处理函数中解耦,并集中到负责业务流程编排的 predictor 核心类中,使代码结构更清晰、职责更分明。相应地,api.py 中的后台训练任务逻辑也进行了简化:它现在只需信任 predictor 返回的结果,无需再关心具体的保存细节。
279 lines
9.8 KiB
Python
279 lines
9.8 KiB
Python
# -*- coding: utf-8 -*-
|
||
"""
|
||
CNN-BiLSTM-Attention 模型训练器
|
||
"""
|
||
|
||
import pandas as pd
|
||
import torch
|
||
import torch.optim as optim
|
||
import numpy as np
|
||
import os
|
||
import time
|
||
import copy
|
||
|
||
from models.model_registry import register_trainer
|
||
from utils.model_manager import model_manager
|
||
from analysis.metrics import evaluate_model
|
||
from utils.data_utils import prepare_data, prepare_sequences
|
||
from sklearn.preprocessing import MinMaxScaler
|
||
from utils.visualization import plot_loss_curve # 导入绘图函数
|
||
import json # 导入json库
|
||
|
||
# 导入新创建的模型
|
||
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||
|
||
def train_with_cnn_bilstm_attention(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int,
    sequence_length: int,
    forecast_horizon: int,
    model_dir: str,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    clip_norm: float = 1.0,  # gradient-clipping threshold (max gradient norm)
    **kwargs
):
    """Train a CNN-BiLSTM-Attention model (new data-pipeline version).

    The function runs five stages in order:

    1. Prepare data with ``prepare_data`` (target column
       ``'net_sales_quantity'``) and wrap the splits with
       ``prepare_sequences``.
    2. Build a ``CnnBiLstmAttention`` model, an Adam optimizer and an MSE loss.
    3. Train with gradient clipping and early stopping. The test split is
       also used as the validation set. Every time the validation loss
       improves, a checkpoint is saved through ``model_manager.save_model``
       under the version ``f"{current_version}_best"``.
    4. Evaluate the best weights on the test split with ``evaluate_model``
       after inverse-transforming predictions and targets with ``scaler_y``.
    5. Save artifacts: the loss-curve plot, the final versioned model, and a
       JSON file holding the loss history.

    Args:
        model_identifier: Label used in the start-up log line and as the
            fallback product name when ``training_df`` has no
            ``'product_name'`` column.
        training_df: Training data forwarded to ``prepare_data``.
        feature_list: Pre-selected feature names forwarded to
            ``prepare_data``.
        training_mode: Training scope. ``'product'``, ``'store'`` and
            ``'global'`` select ``product_id``, ``store_id`` and
            ``aggregation_method`` respectively as the file-name identifier;
            any other value falls back to the product name.
        epochs: Maximum number of training epochs.
        sequence_length: Input window length forwarded to ``prepare_data``
            and to the model constructor.
        forecast_horizon: Forwarded to ``prepare_data`` and used as the
            model's ``output_dim``.
        model_dir: Directory for the loss-curve plot and the loss JSON file.
        product_id: Forwarded to ``model_manager``.
        store_id: Forwarded to ``model_manager``.
        aggregation_method: Forwarded to ``model_manager``.
        version: Accepted for interface compatibility but not read here; the
            version is always taken from ``model_manager.peek_next_version``.
        clip_norm: ``max_norm`` passed to ``clip_grad_norm_``.
        **kwargs: Optional overrides: ``batch_size`` (default 32),
            ``learning_rate`` (default 0.001), ``patience`` (default 15).

    Returns:
        A ``(metrics, artifacts)`` tuple. ``metrics`` is the result of
        ``evaluate_model`` extended with ``training_time``,
        ``best_val_loss`` and ``stopped_epoch``. ``artifacts`` is a dict with
        the keys ``versioned_model``, ``loss_curve_plot``,
        ``loss_curve_data``, ``best_model``, ``version`` and
        ``loss_history``.
    """
    print(f"🚀 CNN-BiLSTM-Attention 训练器启动: model_identifier='{model_identifier}'")
    start_time = time.time()

    # --- 1. Data preparation ---
    product_name = training_df['product_name'].iloc[0] if 'product_name' in training_df.columns else model_identifier
    print(f"[Hybrid] 开始数据预处理,使用 {len(feature_list)} 个预选特征...")

    # Use the standardized data-preparation function.
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y, used_features = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )

    # Use the standardized sequence-creation function.
    batch_size = kwargs.get('batch_size', 32)  # batch_size may be supplied by the caller
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)

    # --- 2. Instantiate the model and the optimizer ---
    input_dim = trainX.shape[2]  # number of features

    model = CnnBiLstmAttention(
        input_dim=input_dim,
        output_dim=forecast_horizon,
        sequence_length=sequence_length
    )

    optimizer = optim.Adam(model.parameters(), lr=kwargs.get('learning_rate', 0.001))
    criterion = torch.nn.MSELoss()

    # Configuration stored with every checkpoint; built once so the "best"
    # and the final checkpoint cannot drift apart.
    model_config = {
        'model_type': 'cnn_bilstm_attention',
        'input_dim': input_dim,
        'output_dim': forecast_horizon,
        'sequence_length': sequence_length,
        'features': used_features
    }

    # --- 3. Training loop with early stopping ---
    print("开始训练 CNN-BiLSTM-Attention 模型 (含早停)...")

    # Version lock: fix the version number of this run before training starts.
    current_version = model_manager.peek_next_version(
        model_type='cnn_bilstm_attention',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    loss_history = {'train': [], 'val': []}
    best_val_loss = float('inf')
    best_model_state = None
    best_model_path = None  # path of the most recently saved best checkpoint
    patience = kwargs.get('patience', 15)
    patience_counter = 0
    # Tracked explicitly so it is defined even when the loop body never runs
    # (epochs <= 0); the loop variable alone would be unbound in that case.
    stopped_epoch = 0

    for epoch in range(epochs):
        stopped_epoch = epoch + 1

        model.train()
        epoch_train_loss = 0
        for X_batch, y_batch in train_loader:
            optimizer.zero_grad()
            outputs = model(X_batch)
            train_loss = criterion(outputs, y_batch.squeeze(-1))
            train_loss.backward()
            # --- Gradient clipping ---
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()
            epoch_train_loss += train_loss.item()
        train_loss = epoch_train_loss / len(train_loader)

        # Validation (runs on the test loader).
        model.eval()
        epoch_val_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                val_outputs = model(X_batch)
                val_loss = criterion(val_outputs, y_batch.squeeze(-1))
                epoch_val_loss += val_loss.item()
        val_loss = epoch_val_loss / len(test_loader)

        loss_history['train'].append(train_loss)
        loss_history['val'].append(val_loss)

        if (epoch + 1) % 10 == 0:
            print(f'Epoch [{epoch+1}/{epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Early-stopping logic.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            # Deep copy: state_dict() returns references to live parameters
            # that later optimizer steps would otherwise overwrite.
            best_model_state = copy.deepcopy(model.state_dict())
            patience_counter = 0
            print(f"✨ 新的最佳模型! Epoch: {epoch+1}, Val Loss: {best_val_loss:.4f}")

            # Persist the best model immediately.
            best_model_data = {
                'model_state_dict': best_model_state,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': dict(model_config),
                'epoch': epoch + 1
            }
            best_model_path, _ = model_manager.save_model(
                model_data=best_model_data,
                product_id=product_id,
                model_type='cnn_bilstm_attention',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"🚫 早停触发! 在 epoch {epoch+1} 停止。")
                break

    training_time = time.time() - start_time
    print(f"模型训练完成,耗时: {training_time:.2f}秒")

    # --- 4. Evaluate with the best model ---
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
        print("最佳模型已加载用于最终评估。")
    else:
        # No epoch ever improved the validation loss (e.g. NaN losses or
        # zero epochs). Snapshot the current weights so the final checkpoint
        # matches the model that is evaluated below instead of storing None.
        best_model_state = copy.deepcopy(model.state_dict())

    model.eval()
    with torch.no_grad():
        all_test_X = []
        all_test_Y = []
        for X_batch, y_batch in test_loader:
            all_test_X.append(X_batch)
            all_test_Y.append(y_batch)

        testX_tensor = torch.cat(all_test_X, dim=0)
        testY_tensor = torch.cat(all_test_Y, dim=0)

        test_pred_scaled = model(testX_tensor)

    # Map predictions and targets back to the original target scale.
    test_pred_unscaled = scaler_y.inverse_transform(test_pred_scaled.numpy())
    test_true_unscaled = scaler_y.inverse_transform(testY_tensor.squeeze(-1).numpy())

    metrics = evaluate_model(test_true_unscaled.flatten(), test_pred_unscaled.flatten())
    metrics['training_time'] = training_time
    metrics['best_val_loss'] = best_val_loss
    metrics['stopped_epoch'] = stopped_epoch

    print("\n最佳模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}, RMSE: {metrics['rmse']:.4f}, MAE: {metrics['mae']:.4f}, R²: {metrics['r2']:.4f}, MAPE: {metrics['mape']:.2f}%")

    # --- 5. Save artifacts ---

    # Derive scope and identifier for the standardized file names.
    scope = training_mode
    if scope == 'product':
        identifier = product_id
    elif scope == 'store':
        identifier = store_id
    elif scope == 'global':
        identifier = aggregation_method
    else:
        identifier = product_name  # fallback

    # Plot the loss curve, tagged with the version number.
    loss_curve_path = plot_loss_curve(
        train_losses=loss_history['train'],
        val_losses=loss_history['val'],
        model_type='cnn_bilstm_attention',
        scope=scope,
        identifier=identifier,
        version=current_version,  # use the locked version
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    # Assemble the final model payload.
    model_data = {
        'model_state_dict': best_model_state,  # state of the best model
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': dict(model_config),
        'metrics': metrics,
        'loss_history': loss_history,
        'loss_curve_path': loss_curve_path  # include the path directly
    }

    # Save the final model through the model manager.
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type='cnn_bilstm_attention',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version  # use the locked version number
    )
    print(f"✅ CNN-BiLSTM-Attention 最终模型已保存,版本: {final_version}")

    # --- Save the loss history as a JSON file ---
    loss_data_filename = f"{identifier}_{current_version}_loss_curve_data.json"
    loss_data_path = os.path.join(model_dir, loss_data_filename)

    # Data to be written; the validation losses are stored under 'test_loss'.
    loss_data_to_save = {
        'epochs': list(range(1, len(loss_history['train']) + 1)),
        'train_loss': loss_history['train'],
        'test_loss': loss_history['val']
    }

    with open(loss_data_path, 'w', encoding='utf-8') as f:
        json.dump(loss_data_to_save, f)
    print(f"💾 损失历史数据已保存: {loss_data_path}")

    # Assemble the returned artifacts.
    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "loss_curve_data": loss_data_path,
        "best_model": best_model_path,
        "version": final_version,
        'loss_history': loss_history
    }

    return metrics, artifacts
|
||
|
||
# --- Key step: register this trainer with the system ---
# Runs at import time and passes the trainer to register_trainer under the
# key 'cnn_bilstm_attention'.
register_trainer('cnn_bilstm_attention', train_with_cnn_bilstm_attention)