# ShopTRAINING/server/trainers/tcn_trainer.py
# 2025-07-22 15:41:05 +08:00
#
# 386 lines
# 14 KiB
# Python

"""
药店销售预测系统 - TCN模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.tcn_model import TCNForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
def _make_progress_emitter(socketio, task_id):
    """Build the callback used to push training progress to the frontend.

    The returned ``emit_progress(message, progress=None, metrics=None)`` emits a
    ``'training_progress'`` event on the ``'/training'`` namespace. It is a
    no-op unless both ``socketio`` and ``task_id`` are truthy, so callers can
    use it unconditionally.
    """
    def emit_progress(message, progress=None, metrics=None):
        if not (socketio and task_id):
            return
        data = {
            'task_id': task_id,
            'message': message,
            'timestamp': time.time(),
        }
        if progress is not None:
            data['progress'] = progress
        if metrics is not None:
            data['metrics'] = metrics
        socketio.emit('training_progress', data, namespace='/training')

    return emit_progress


def _align_target(y_batch):
    """Give 2-D targets a trailing axis so they line up with the model output.

    NOTE(review): assumes TCNForecaster returns (batch, horizon, 1) — confirm.
    """
    return y_batch.unsqueeze(-1) if y_batch.dim() == 2 else y_batch


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader`` and return the mean batch loss."""
    model.train()
    epoch_loss = 0.0
    last_batch = len(loader) - 1
    for batch_idx, (X_batch, y_batch) in enumerate(loader):
        X_batch = X_batch.to(DEVICE)
        y_batch = _align_target(y_batch.to(DEVICE))
        loss = criterion(model(X_batch), y_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        # Report every 10th batch plus the last one to limit progress traffic.
        if batch_idx % 10 == 0 or batch_idx == last_batch:
            current_lr = optimizer.param_groups[0]['lr']
            progress_manager.update_batch(batch_idx, loss.item(), current_lr)
    return epoch_loss / len(loader)


def _validate(model, loader, criterion):
    """Evaluate ``model`` over ``loader`` without gradients; return the mean batch loss."""
    model.eval()
    val_loss = 0.0
    num_batches = len(loader)
    with torch.no_grad():
        for batch_idx, (X_batch, y_batch) in enumerate(loader):
            X_batch = X_batch.to(DEVICE)
            y_batch = _align_target(y_batch.to(DEVICE))
            val_loss += criterion(model(X_batch), y_batch).item()
            if batch_idx % 5 == 0 or batch_idx == num_batches - 1:
                # Count the current batch as finished so the last report hits 100%.
                progress_manager.set_stage("validation", (batch_idx + 1) / num_batches * 100)
    return val_loss / num_batches


def _build_model_data(model, optimizer, epoch, train_losses, test_losses,
                      scaler_X, scaler_y, config, training_info):
    """Assemble the checkpoint payload shared by the periodic and final saves.

    ``config`` and ``training_info`` are copied so every checkpoint gets its own
    dicts, and ``training_info`` is stamped with the current ``time.time()``.
    """
    return {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
        'test_loss': test_losses[-1],
        'train_losses': train_losses,
        'test_losses': test_losses,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': dict(config),
        'training_info': {**training_info, 'timestamp': time.time()},
    }


def train_product_model_with_tcn(
    product_id,
    model_identifier,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    sequence_length=LOOK_BACK,
    forecast_horizon=FORECAST_HORIZON,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False
):
    """Train a TCN sales-forecasting model for one product and save it.

    The sales series is sorted by ``date``, split chronologically 80/20 into
    train/test, min-max scaled (scalers fitted on the training part only),
    windowed with ``create_dataset`` and used to train ``TCNForecaster`` with
    MSE loss and Adam. A ``version='best'`` checkpoint is saved at checkpoint
    epochs whenever the test loss improves; the final model is saved at the end.

    Args:
        product_id: Product being modelled; used to load data when
            ``product_df`` is ``None`` and recorded in ``training_info``.
        model_identifier: Unique key passed to ``model_manager.save_model`` as
            ``product_id`` so models of different scopes do not collide.
        product_df: Pre-loaded sales frame with a ``date``, ``product_name`` and
            the feature columns listed below. When ``None`` the data is fetched
            via ``aggregate_multi_store_data(product_id, aggregation_method)``.
        store_id: Recorded in the checkpoint and passed to ``save_model``.
        training_mode: Recorded in the checkpoint and passed to ``save_model``.
        aggregation_method: Aggregation used when loading data; also recorded.
        epochs: Number of training epochs; must be at least 1.
        sequence_length: Input window length (look-back).
        forecast_horizon: Number of future steps predicted per window.
        model_dir: Directory passed to ``plot_loss_curve``.
        version: Currently unused (``save_model`` picks the final version).
        socketio: Optional Socket.IO server used for progress events.
        task_id: Identifier attached to progress events.
        continue_training: Currently unused (see the TODO below).

    Returns:
        Tuple ``(model, metrics, epochs, final_model_path)``.

    Raises:
        ValueError: If ``epochs < 1``, the product has no data, or there is too
            little data to build both training and test windows.
    """
    emit_progress = _make_progress_emitter(socketio, task_id)
    if epochs < 1:
        # Without at least one epoch the loss lists stay empty and the final
        # checkpoint would fail on train_losses[-1].
        raise ValueError(f"epochs 必须为正整数, 当前为 {epochs}")
    emit_progress("开始训练 TCN 模型")

    if product_df is None:
        from utils.multi_store_data_utils import aggregate_multi_store_data
        product_df = aggregate_multi_store_data(
            product_id=product_id,
            aggregation_method=aggregation_method
        )
        training_scope = f"全局聚合({aggregation_method})"
    else:
        training_scope = "所有店铺"

    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
        )
        print(error_msg)
        emit_progress(f"训练失败:数据不足 ({len(product_df)}/{min_required_samples} 天)")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]
    print(f"使用TCN模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
    print(f"训练范围: {training_scope}")
    print(f"使用设备: {DEVICE}")
    print(f"模型将保存到目录: {model_dir}")
    emit_progress(f"训练产品: {product_name} (ID: {product_id})")

    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values

    progress_manager.set_stage("data_preprocessing", 0)
    emit_progress("数据预处理中...")

    # Chronological 80/20 split. The scalers are fitted on the training slice
    # only so that test-period extremes do not leak into training/evaluation.
    train_size = int(len(X) * 0.8)
    scaler_X = MinMaxScaler(feature_range=(0, 1)).fit(X[:train_size])
    scaler_y = MinMaxScaler(feature_range=(0, 1)).fit(y[:train_size])
    X_scaled = scaler_X.transform(X)
    y_scaled = scaler_y.transform(y)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
    progress_manager.set_stage("data_preprocessing", 50)

    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
    # The length check above covers the whole series, but each split must hold
    # complete windows on its own; an empty split would otherwise make the
    # per-epoch loss averages below divide by zero.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"按 80/20 划分后无法构造完整的训练/测试样本窗口 "
            f"(LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"训练集: {len(X_train)} 天 -> {len(trainX)} 个样本, "
            f"测试集: {len(X_test)} 天 -> {len(testX)} 个样本\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
        )
        print(error_msg)
        emit_progress(f"训练失败:数据不足 (训练样本 {len(trainX)}, 测试样本 {len(testX)})")
        raise ValueError(error_msg)

    testX_tensor = torch.Tensor(testX)
    train_dataset = PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY))
    test_dataset = PharmacyDataset(testX_tensor, torch.Tensor(testY))

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    progress_manager.total_batches_per_epoch = len(train_loader)
    progress_manager.batch_size = batch_size
    progress_manager.total_samples = len(train_dataset)
    progress_manager.set_stage("data_preprocessing", 100)

    input_dim = X_train.shape[1]
    output_dim = forecast_horizon
    hidden_size = 64
    num_layers = 3
    kernel_size = 3
    dropout_rate = 0.2
    model = TCNForecaster(
        num_features=input_dim,
        output_sequence_length=output_dim,
        num_channels=[hidden_size] * num_layers,
        kernel_size=kernel_size,
        dropout=dropout_rate
    )
    # TODO: Implement continue_training logic with the new model_manager
    model = model.to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Shared by every checkpoint saved below.
    config = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_channels': [hidden_size] * num_layers,
        'dropout': dropout_rate,
        'kernel_size': kernel_size,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'tcn'
    }
    training_info = {
        'product_id': product_id,
        'product_name': product_name,
        'training_mode': training_mode,
        'store_id': store_id,
        'aggregation_method': aggregation_method,
    }

    emit_progress("开始模型训练...")
    train_losses = []
    test_losses = []
    start_time = time.time()
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    progress_manager.set_stage("model_training", 0)
    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")

    for epoch in range(epochs):
        progress_manager.start_epoch(epoch)
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(train_loss)

        progress_manager.set_stage("validation", 0)
        test_loss = _validate(model, test_loader, criterion)
        test_losses.append(test_loss)
        progress_manager.finish_epoch(train_loss, test_loss)

        epoch_num = epoch + 1
        is_last_epoch = epoch_num == epochs
        if epoch_num % 5 == 0 or is_last_epoch:
            emit_progress(
                f"Epoch {epoch_num}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                progress=(epoch_num / epochs) * 100,
                metrics={
                    'train_loss': train_loss,
                    'test_loss': test_loss,
                    'epoch': epoch_num,
                    'total_epochs': epochs
                }
            )

        # "Best" is only evaluated on checkpoint epochs, which bounds disk writes.
        is_checkpoint_epoch = epoch_num % checkpoint_interval == 0 or is_last_epoch
        if is_checkpoint_epoch and test_loss < best_loss:
            best_loss = test_loss
            model_manager.save_model(
                model_data=_build_model_data(
                    model, optimizer, epoch_num, train_losses, test_losses,
                    scaler_X, scaler_y, config, training_info
                ),
                product_id=model_identifier,  # unique identifier, not the raw product_id
                model_type='tcn',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version='best'
            )
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch_num}, test_loss: {test_loss:.4f})")

        if epoch_num % 10 == 0:
            print(f"Epoch {epoch_num}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    training_time = time.time() - start_time
    progress_manager.set_stage("model_saving", 0)
    emit_progress("训练完成,正在保存模型...")
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        'TCN',
        model_dir=model_dir
    )
    print(f"损失曲线已保存到: {loss_curve_path}")

    # Evaluate on the original (inverse-scaled) sales values.
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).squeeze(-1).cpu().numpy()
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}")

    final_model_data = _build_model_data(
        model, optimizer, epochs, train_losses, test_losses,
        scaler_X, scaler_y, config, training_info
    )
    final_model_data['metrics'] = metrics
    final_model_data['loss_curve_path'] = loss_curve_path
    final_model_data['training_info']['training_completed'] = True

    progress_manager.set_stage("model_saving", 50)
    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=model_identifier,  # unique identifier, not the raw product_id
        model_type='tcn',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name
    )
    progress_manager.set_stage("model_saving", 100)

    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
        'mae': metrics['mae'],
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs,
        'version': final_version
    }
    emit_progress(f"模型训练完成!版本 {final_version} 已保存", progress=100, metrics=final_metrics)
    return model, metrics, epochs, final_model_path
# --- Register this trainer with the system ---
# Runs at import time: hands train_product_model_with_tcn to
# models.model_registry.register_trainer under the name 'tcn' (the lookup
# semantics live in models.model_registry, not here).
# NOTE(review): the import sits at the bottom, after the trainer is defined,
# rather than with the other imports — presumably to avoid a circular import
# with models.model_registry; confirm before moving it.
from models.model_registry import register_trainer
register_trainer('tcn', train_product_model_with_tcn)