ShopTRAINING/server/trainers/tcn_trainer.py
xz2000 348741d49c Fix the atomicity problem in the model-saving process so that the system keeps its data consistent whenever an error occurs, avoiding "orphan files" and incomplete database records.
Key changes and outcomes:

Implemented atomic model-saving logic:

Introduced a new private method, _save_model_transactional, in the core PharmacyPredictor class (ShopTRAINING/server/core/predictor.py).
This method wraps all model-saving operations (writing every file artifact to disk and writing the metadata to the database) into a single transactional unit.
A try...except block ensures that the save is considered successful only after the metadata has been recorded in the database.
Added an automatic rollback mechanism:

If any error occurs while _save_model_transactional is executing (for example, a database connection failure or insufficient disk space), the except block triggers a rollback.
The rollback is performed by _rollback_files, which deletes every file generated during the failed training run (such as the .pth model file, the .png loss-curve plot, and the .json loss data).
This guarantees that the file system is never left with "orphan" files that have no matching database record, keeping the system clean and consistent (see the sketch after this commit header).
Refactored the core training flow:

Modified the train_model method of PharmacyPredictor so that, after a trainer returns successfully, it no longer returns directly but instead calls the new _save_model_transactional method.
This decouples the saving logic from the route handlers in api.py and centralizes it in the predictor core class, which orchestrates the business flow, making the code structure clearer and the responsibilities better separated.
Accordingly, the background training-task logic in api.py was simplified: it now only needs to trust the result returned by the predictor, without dealing with the details of saving.
2025-07-29 16:30:39 +08:00
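Below is a minimal sketch of the transactional save flow described above. It assumes hypothetical signatures and a hypothetical self.db.insert_model_metadata helper; the actual implementation lives in ShopTRAINING/server/core/predictor.py and may differ:

import os

class PharmacyPredictor:
    def _save_model_transactional(self, artifacts: dict, metadata: dict) -> dict:
        """Treat the file writes plus the metadata DB write as one all-or-nothing save."""
        written_files = [p for p in artifacts.values() if isinstance(p, str)]
        try:
            # The .pth / .png / .json artifacts are already on disk at this point;
            # the database insert is the commit point of the whole transaction.
            self.db.insert_model_metadata(metadata)  # hypothetical DB helper
            return artifacts
        except Exception:
            # On any failure (DB down, disk full, ...), delete what was written
            # so no orphan files outlive the failed save.
            self._rollback_files(written_files)
            raise

    def _rollback_files(self, paths: list) -> None:
        """Delete every file artifact produced by the failed training run."""
        for path in paths:
            if path and os.path.exists(path):
                os.remove(path)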

"""
Pharmacy Sales Prediction System - TCN model training function
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
import json
from models.tcn_model import TCNForecaster
from utils.data_utils import prepare_data, PharmacyDataset, prepare_sequences
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
def train_product_model_with_tcn(
model_identifier: str,
training_df: pd.DataFrame,
feature_list: list,
training_mode: str,
epochs: int = 50,
sequence_length: int = LOOK_BACK,
forecast_horizon: int = FORECAST_HORIZON,
model_dir: str = DEFAULT_MODEL_DIR,
product_id: str = None,
store_id: str = None,
aggregation_method: str = None,
version: str = None,
socketio=None,
task_id: str = None,
progress_callback=None,
**kwargs
):
"""
Train a product sales prediction model with TCN (new data-pipeline version)
"""
def emit_progress(message, progress=None, metrics=None):
"""发送训练进度到前端"""
if socketio and task_id:
data = {
'task_id': task_id,
'message': message,
'timestamp': time.time()
}
if progress is not None:
data['progress'] = progress
if metrics is not None:
data['metrics'] = metrics
socketio.emit('training_progress', data, namespace='/training')
emit_progress(f"开始训练 TCN 模型")
min_required_samples = sequence_length + forecast_horizon
if len(training_df) < min_required_samples:
error_msg = f"训练数据不足: 需要 {min_required_samples} 条记录, 但只有 {len(training_df)} 条。"
emit_progress(error_msg)
raise ValueError(error_msg)
product_name = training_df['product_name'].iloc[0] if 'product_name' in training_df.columns else model_identifier
emit_progress(f"开始为 '{product_name}' (标识: {model_identifier}) 训练TCN模型")
# --- 新数据管道核心改造 ---
emit_progress("数据预处理中...")
# 1. 使用标准化的 prepare_data 函数处理数据
_, _, trainX, testX, trainY, testY, scaler_X, scaler_y, used_features = prepare_data(
training_df=training_df,
feature_list=feature_list,
target_column='net_sales_quantity',
sequence_length=sequence_length,
forecast_horizon=forecast_horizon
)
# 2. Create DataLoaders with the standardized prepare_sequences function
batch_size = 32
train_loader = prepare_sequences(trainX, trainY, batch_size)
test_loader = prepare_sequences(testX, testY, batch_size)
total_batches = len(train_loader)
total_samples = len(trainX)
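# Feed batch and sample counts into the shared progress manager so per-epoch progress can be reported.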
if hasattr(progress_manager, 'total_batches_per_epoch'):
progress_manager.total_batches_per_epoch = total_batches
progress_manager.batch_size = batch_size
progress_manager.total_samples = total_samples
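# Model hyperparameters: three temporal blocks of 64 channels each, kernel size 3, dropout 0.2.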
input_dim = trainX.shape[2]
output_dim = forecast_horizon
hidden_size = 64
num_layers = 3
kernel_size = 3
dropout_rate = 0.2
model = TCNForecaster(
num_features=input_dim,
output_sequence_length=output_dim,
num_channels=[hidden_size] * num_layers,
kernel_size=kernel_size,
dropout=dropout_rate
)
# TODO: Implement continue_training logic with the new model_manager
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
emit_progress("开始模型训练...")
train_losses = []
test_losses = []
start_time = time.time()
# Lock the version for this training run
current_version = model_manager.peek_next_version(
model_type='tcn',
product_id=product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method
)
print(f"🔒 本次训练版本锁定为: {current_version}")
checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf')
best_model_path = None
progress_manager.set_stage("model_training", 0)
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")
for epoch in range(epochs):
progress_manager.start_epoch(epoch)
model.train()
epoch_loss = 0
for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
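# Align 2D targets (batch, horizon) with the model output's trailing channel dimension.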
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
current_lr = optimizer.param_groups[0]['lr']
progress_manager.update_batch(batch_idx, loss.item(), current_lr)
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
progress_manager.set_stage("validation", 0)
model.eval()
test_loss = 0
with torch.no_grad():
for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
test_loss += loss.item()
if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
val_progress = (batch_idx / len(test_loader)) * 100
progress_manager.set_stage("validation", val_progress)
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
progress_manager.finish_epoch(train_loss, test_loss)
if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
progress = ((epoch + 1) / epochs) * 100
current_metrics = {
'train_loss': train_loss,
'test_loss': test_loss,
'epoch': epoch + 1,
'total_epochs': epochs
}
emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
progress=progress, metrics=current_metrics)
if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
checkpoint_data = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_loss,
'test_loss': test_loss,
'train_losses': train_losses,
'test_losses': test_losses,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_channels': [hidden_size] * num_layers,
'dropout': dropout_rate,
'kernel_size': kernel_size,
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'tcn',
'features': used_features
},
'training_info': {
'product_id': product_id,
'product_name': product_name,
'training_mode': training_mode,
'store_id': store_id,
'aggregation_method': aggregation_method,
'timestamp': time.time()
}
}
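# Keep a separate "_best" checkpoint whenever the validation loss improves.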
if test_loss < best_loss:
best_loss = test_loss
best_model_path, _ = model_manager.save_model(
model_data=checkpoint_data,
product_id=product_id,
model_type='tcn',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version=f"{current_version}_best"
)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
training_time = time.time() - start_time
progress_manager.set_stage("model_saving", 0)
emit_progress("训练完成,正在保存模型...")
model.eval()
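# Re-run the whole test set in one pass to compute the final evaluation metrics.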
with torch.no_grad():
all_test_X = []
all_test_Y = []
for X_batch, y_batch in test_loader:
all_test_X.append(X_batch)
all_test_Y.append(y_batch)
testX_tensor = torch.cat(all_test_X, dim=0)
testY_tensor = torch.cat(all_test_Y, dim=0)
test_pred = model(testX_tensor.to(DEVICE))
test_pred = test_pred.squeeze(-1).cpu().numpy()
# scaler_y.inverse_transform expects 2D input, hence the squeeze above
test_pred_inv = scaler_y.inverse_transform(test_pred)
test_true_inv = scaler_y.inverse_transform(testY_tensor.cpu().numpy())
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}")
print(f"RMSE: {metrics['rmse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")
print(f"R²: {metrics['r2']:.4f}")
print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {training_time:.2f}")
final_model_data = {
'epoch': epochs,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'train_loss': train_losses[-1],
'test_loss': test_losses[-1],
'train_losses': train_losses,
'test_losses': test_losses,
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'num_layers': num_layers,
'num_channels': [hidden_size] * num_layers,
'dropout': dropout_rate,
'kernel_size': kernel_size,
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': 'tcn',
'features': used_features
},
'metrics': metrics,
'training_info': {
'product_id': product_id,
'product_name': product_name,
'training_mode': training_mode,
'store_id': store_id,
'aggregation_method': aggregation_method,
'timestamp': time.time(),
'training_completed': True
}
}
progress_manager.set_stage("model_saving", 50)
final_model_path, final_version = model_manager.save_model(
model_data=final_model_data,
product_id=product_id,
model_type='tcn',
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version=current_version
)
progress_manager.set_stage("model_saving", 100)
final_metrics = {
'mse': metrics['mse'],
'rmse': metrics['rmse'],
'mae': metrics['mae'],
'r2': metrics['r2'],
'mape': metrics['mape'],
'training_time': training_time,
'final_epoch': epochs,
'version': final_version
}
emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
# 准备 scope 和 identifier 以生成标准化的文件名
scope = training_mode
if scope == 'product':
identifier = model_identifier
elif scope == 'store':
identifier = store_id
elif scope == 'global':
identifier = aggregation_method
else:
identifier = product_name  # fallback
# Plot the loss curve with the version number in its name
loss_curve_path = plot_loss_curve(
train_losses=train_losses,
val_losses=test_losses,
model_type='tcn',
scope=scope,
identifier=identifier,
version=current_version,  # use the locked version
model_dir=model_dir
)
print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
# 更新模型数据中的损失图路径
final_model_data['loss_curve_path'] = loss_curve_path
# --- Also save the loss history as a JSON file ---
loss_data_filename = f"{identifier}_{current_version}_loss_curve_data.json"
loss_data_path = os.path.join(model_dir, loss_data_filename)
loss_data_to_save = {
'epochs': list(range(1, len(train_losses) + 1)),
'train_loss': train_losses,
'test_loss': test_losses
}
with open(loss_data_path, 'w') as f:
json.dump(loss_data_to_save, f)
print(f"💾 损失历史数据已保存: {loss_data_path}")
artifacts = {
"versioned_model": final_model_path,
"loss_curve_plot": loss_curve_path,
"loss_curve_data": loss_data_path,
"best_model": best_model_path,
"version": final_version,
"loss_history": {
"train": train_losses,
"val": test_losses
}
}
return metrics, artifacts
# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('tcn', train_product_model_with_tcn)