ShopTRAINING/server/trainers/transformer_trainer.py

451 lines
16 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - Transformer模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
2025-07-02 11:05:23 +08:00
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from models.transformer_model import TimeSeriesTransformer
from utils.data_utils import create_dataset, PharmacyDataset
2025-07-02 11:05:23 +08:00
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
2025-07-02 11:05:23 +08:00
from core.config import (
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
2025-07-02 11:05:23 +08:00
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
2025-07-02 11:05:23 +08:00
def train_product_model_with_transformer(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False,
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
):
    """Train a Transformer sales-forecasting model for one product.

    The data is scaled with ``MinMaxScaler``, split 80/20 in chronological
    order, windowed with ``create_dataset`` and used to train a
    ``TimeSeriesTransformer`` with Adam, ``ReduceLROnPlateau`` and early
    stopping. The best checkpoint (lowest test loss) and the final model are
    stored through ``model_manager.save_model`` and a loss curve is plotted.

    Args:
        product_id: Product to train for. Used to load data when
            ``product_df`` is None and recorded in the saved training info.
        model_identifier: Identifier passed to ``model_manager`` as
            ``product_id`` when reserving the version and saving models.
        product_df: Sales data. Must provide the columns ``date``,
            ``product_name``, ``sales``, ``weekday``, ``month``,
            ``is_holiday``, ``is_weekend``, ``is_promotion`` and
            ``temperature``. When None it is loaded with
            ``aggregate_multi_store_data``.
        store_id: Forwarded to ``model_manager`` and stored in the metadata.
        training_mode: Forwarded to ``model_manager``; also selects the
            identifier used for the loss-curve file name.
        aggregation_method: Aggregation used when loading data; forwarded to
            ``model_manager``.
        epochs: Maximum number of training epochs.
        sequence_length: Input window length passed to ``create_dataset`` and
            the model.
        forecast_horizon: Number of predicted steps (model output length).
        model_dir: Directory passed to ``plot_loss_curve``.
        version: Accepted for interface compatibility; not used here, the
            version comes from ``model_manager.peek_next_version``.
        socketio: Optional Socket.IO server. Progress events are emitted only
            when both ``socketio`` and ``task_id`` are truthy.
        task_id: Task identifier attached to emitted progress events.
        continue_training: Accepted for interface compatibility; not used.
        patience: Number of consecutive epochs without test-loss improvement
            before training stops. The LR scheduler uses ``patience // 2``.
        learning_rate: Initial Adam learning rate.
        clip_norm: Max gradient norm; a falsy value disables clipping.

    Returns:
        tuple: ``(model, final_metrics, artifacts)`` where ``final_metrics``
        holds mse/rmse/mae/r2/mape, training time, the number of epochs
        actually run and the version, and ``artifacts`` holds the saved model
        paths, the loss-curve path and the version.

    Raises:
        ValueError: If there is no data, fewer rows than
            ``sequence_length + forecast_horizon``, or the 80/20 split leaves
            no usable training or test windows.
    """

    def emit_progress(message, progress=None, metrics=None):
        """Send a progress message to the frontend and echo it to stdout."""
        if socketio and task_id:
            data = {
                'task_id': task_id,
                'message': message,
                'timestamp': time.time()
            }
            if progress is not None:
                data['progress'] = progress
            if metrics is not None:
                data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')
        print(f"[{time.strftime('%H:%M:%S')}] {message}", flush=True)
        # Force both standard streams out so logs show up promptly.
        import sys
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress("开始Transformer模型训练...")

    # Fall back to a no-op progress manager when the real one cannot be
    # imported, so training itself never depends on progress reporting.
    try:
        from utils.training_progress import progress_manager
    except ImportError:
        class DummyProgressManager:
            def set_stage(self, *args, **kwargs): pass
            def start_training(self, *args, **kwargs): pass
            def start_epoch(self, *args, **kwargs): pass
            def update_batch(self, *args, **kwargs): pass
            def finish_epoch(self, *args, **kwargs): pass
            def finish_training(self, *args, **kwargs): pass
        progress_manager = DummyProgressManager()

    # Load aggregated data only when the caller did not supply a frame.
    if product_df is None:
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )

    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
        )
        print(error_msg)
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]

    print(f"[Transformer] 训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
    print(f"[Device] 使用设备: {DEVICE}", flush=True)
    print(f"[Model] 模型将保存到目录: {model_dir}", flush=True)

    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # ------------------------------------------------------------------
    # Data preprocessing
    # ------------------------------------------------------------------
    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("数据预处理中...")

    X = product_df[features].values
    y = product_df[['sales']].values

    # NOTE(review): both scalers are fitted on the full series, i.e. before
    # the train/test split below, so test-range statistics leak into scaling.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    progress_manager.set_stage("data_preprocessing", 40)

    # Chronological 80/20 split (rows were sorted by date above).
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    progress_manager.set_stage("data_preprocessing", 70)

    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

    # The length check above covers the whole series only; either split can
    # still be too short to yield a single window. Fail early with a clear
    # error instead of dividing by zero when averaging the epoch losses.
    if len(train_dataset) == 0 or len(test_dataset) == 0:
        raise ValueError(
            f"训练数据不足: 按 8:2 划分后无法构造训练/测试样本 "
            f"(训练样本: {len(train_dataset)}, 测试样本: {len(test_dataset)}, "
            f"sequence_length={sequence_length}, forecast_horizon={forecast_horizon})"
        )

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    progress_manager.total_batches_per_epoch = total_batches
    progress_manager.batch_size = batch_size
    progress_manager.total_samples = total_samples

    progress_manager.set_stage("data_preprocessing", 100)
    emit_progress("数据预处理完成,开始模型训练...")

    # ------------------------------------------------------------------
    # Model, loss, optimizer, scheduler
    # ------------------------------------------------------------------
    input_dim = X_train.shape[1]
    output_dim = forecast_horizon
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1
    num_layers = 3

    model = TimeSeriesTransformer(
        num_features=input_dim,
        d_model=hidden_size,
        nhead=num_heads,
        num_encoder_layers=num_layers,
        dim_feedforward=hidden_size * 2,
        dropout=dropout_rate,
        output_sequence_length=output_dim,
        seq_length=sequence_length,
        batch_size=batch_size
    )
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    # Halve the learning rate after patience // 2 stagnant epochs, i.e. before
    # early stopping (patience epochs) kicks in.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    train_losses = []
    test_losses = []
    start_time = time.time()

    # Reserve the version up front so the best checkpoint, the final model
    # and the loss curve all carry the same version label.
    current_version = model_manager.peek_next_version(
        model_type='transformer',
        product_id=model_identifier,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    # Only reported in the start message below; improvement is checked every
    # epoch so that `patience` really counts epochs.
    checkpoint_interval = max(1, epochs // 10)

    best_loss = float('inf')
    epochs_no_improve = 0
    best_model_path = None
    epochs_run = 0  # epochs actually executed (smaller than `epochs` on early stop)

    # Hyper-parameters stored with every saved model.
    model_config = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'dropout': dropout_rate,
        'num_layers': num_layers,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'transformer'
    }

    progress_manager.set_stage("model_training", 0)
    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")

    # ------------------------------------------------------------------
    # Training loop
    # ------------------------------------------------------------------
    for epoch in range(epochs):
        epochs_run = epoch + 1
        progress_manager.start_epoch(epoch)

        model.train()
        epoch_loss = 0

        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            optimizer.zero_grad()
            loss.backward()
            if clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()

            epoch_loss += loss.item()

            # Throttle progress updates: every 5th batch plus the last one.
            if batch_idx % 5 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Validation pass on the held-out split.
        progress_manager.set_stage("validation", 0)
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                if batch_idx % 3 == 0 or batch_idx == len(test_loader) - 1:
                    val_progress = (batch_idx / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        scheduler.step(test_loss)

        progress_manager.finish_epoch(train_loss, test_loss)

        # Push metrics to the frontend every 5 epochs and on the last epoch.
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
                'train_loss': train_loss,
                'test_loss': test_loss,
                'epoch': epoch + 1,
                'total_epochs': epochs
            }
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # Track the best model on EVERY epoch. Doing this only at checkpoint
        # epochs made `patience` count checkpoints instead of epochs, so early
        # stopping practically never fired and better epochs were missed.
        if test_loss < best_loss:
            best_loss = test_loss
            checkpoint_data = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
                'train_losses': train_losses,
                'test_losses': test_losses,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': dict(model_config),
                'training_info': {
                    'product_id': product_id,
                    'product_name': product_name,
                    'training_mode': training_mode,
                    'store_id': store_id,
                    'aggregation_method': aggregation_method,
                    'timestamp': time.time()
                }
            }
            # Keep the path of the best checkpoint for the returned artifacts.
            best_model_path, _ = model_manager.save_model(
                model_data=checkpoint_data,
                product_id=model_identifier,  # use the unique model identifier
                model_type='transformer',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if (epoch + 1) % 10 == 0:
            print(f"📊 Epoch {epoch+1}/{epochs}, 训练损失: {train_loss:.4f}, 测试损失: {test_loss:.4f}", flush=True)

        # Early stopping.
        if epochs_no_improve >= patience:
            emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
            break

    training_time = time.time() - start_time

    # ------------------------------------------------------------------
    # Final evaluation and persistence
    # ------------------------------------------------------------------
    progress_manager.set_stage("model_saving", 0)
    emit_progress("训练完成,正在保存模型...")

    # Evaluate the last-epoch weights on the whole test set, in the original
    # (un-scaled) sales units.
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
        test_true = testY

    test_pred_inv = scaler_y.inverse_transform(test_pred)
    test_true_inv = scaler_y.inverse_transform(test_true)

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print(f"\n📊 模型评估指标:", flush=True)
    print(f" MSE: {metrics['mse']:.4f}", flush=True)
    print(f" RMSE: {metrics['rmse']:.4f}", flush=True)
    print(f" MAE: {metrics['mae']:.4f}", flush=True)
    print(f" R²: {metrics['r2']:.4f}", flush=True)
    print(f" MAPE: {metrics['mape']:.2f}%", flush=True)
    print(f" ⏱️ 训练时间: {training_time:.2f}", flush=True)

    final_model_data = {
        # Epochs actually run; differs from `epochs` after an early stop.
        'epoch': epochs_run,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
        'test_loss': test_losses[-1],
        'train_losses': train_losses,
        'test_losses': test_losses,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': dict(model_config),
        'metrics': metrics,
        'training_info': {
            'product_id': product_id,
            'product_name': product_name,
            'training_mode': training_mode,
            'store_id': store_id,
            'aggregation_method': aggregation_method,
            'timestamp': time.time(),
            'training_completed': True
        }
    }

    progress_manager.set_stage("model_saving", 50)

    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # use the unique model identifier
        model_type='transformer',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version
    )

    progress_manager.set_stage("model_saving", 100)
    emit_progress(f"模型已保存到 {final_model_path}")
    print(f"💾 模型已保存到 {final_model_path}", flush=True)

    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
        'mae': metrics['mae'],
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs_run,
        'version': final_version
    }

    # Choose the scope/identifier pair used for the standardized file name.
    scope = training_mode
    if scope == 'product':
        identifier = model_identifier
    elif scope == 'store':
        identifier = store_id
    elif scope == 'global':
        identifier = aggregation_method
    else:
        identifier = product_name  # fallback

    # Plot the loss curve, labelled with the locked version.
    loss_curve_path = plot_loss_curve(
        train_losses=train_losses,
        val_losses=test_losses,
        model_type='transformer',
        scope=scope,
        identifier=identifier,
        version=current_version,  # use the locked version
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    # NOTE(review): this runs after save_model() above, so the path is added
    # to the in-memory dict only and is not persisted by this function.
    final_model_data['loss_curve_path'] = loss_curve_path

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version
    }

    return model, final_metrics, artifacts
2025-07-22 15:40:37 +08:00
# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('transformer', train_product_model_with_transformer)