"""
Pharmacy sales forecasting system - TCN model training function
"""

import os
import time

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm

from models.tcn_model import TCNForecaster
from utils.data_utils import prepare_data, PharmacyDataset, prepare_sequences
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
from utils.training_progress import progress_manager
from utils.model_manager import model_manager


def train_product_model_with_tcn(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int = 50,
    sequence_length: int = LOOK_BACK,
    forecast_horizon: int = FORECAST_HORIZON,
    model_dir: str = DEFAULT_MODEL_DIR,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    socketio=None,
    task_id: str = None,
    progress_callback=None,
    **kwargs
):
    """
    Train a product sales forecasting model with TCN (new data-pipeline version).

    Returns a ``(metrics, artifacts)`` tuple. ``artifacts`` holds the paths of the
    versioned model, the loss-curve plot and the best checkpoint, plus the saved
    version. The ``version`` and ``progress_callback`` arguments are accepted but
    currently unused; the version is locked via ``model_manager.peek_next_version``.
    """

    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the frontend over Socket.IO."""
        if socketio and task_id:
            data = {
                'task_id': task_id,
                'message': message,
                'timestamp': time.time()
            }
            if progress is not None:
                data['progress'] = progress
            if metrics is not None:
                data['metrics'] = metrics
            socketio.emit('training_progress', data, namespace='/training')

    emit_progress("Starting TCN model training")

    min_required_samples = sequence_length + forecast_horizon
    if len(training_df) < min_required_samples:
        error_msg = (f"Insufficient training data: {min_required_samples} records required, "
                     f"but only {len(training_df)} available.")
        emit_progress(error_msg)
        raise ValueError(error_msg)

    product_name = (training_df['product_name'].iloc[0]
                    if 'product_name' in training_df.columns else model_identifier)
    emit_progress(f"Starting TCN training for '{product_name}' (identifier: {model_identifier})")

    # --- Core changes for the new data pipeline ---
    emit_progress("Preprocessing data...")

    # 1. Process the data with the standardized prepare_data function
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )

    # 2. Build the DataLoaders with the standardized prepare_sequences function
    batch_size = 32
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)

    total_batches = len(train_loader)
    total_samples = len(trainX)
    if hasattr(progress_manager, 'total_batches_per_epoch'):
        progress_manager.total_batches_per_epoch = total_batches
        progress_manager.batch_size = batch_size
        progress_manager.total_samples = total_samples

    # Model hyperparameters
    input_dim = trainX.shape[2]
    output_dim = forecast_horizon
    hidden_size = 64
    num_layers = 3
    kernel_size = 3
    dropout_rate = 0.2

    # Architecture config stored with every checkpoint so the model can be rebuilt
    model_config = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_layers': num_layers,
        'num_channels': [hidden_size] * num_layers,
        'dropout': dropout_rate,
        'kernel_size': kernel_size,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'tcn'
    }

    model = TCNForecaster(
        num_features=input_dim,
        output_sequence_length=output_dim,
        num_channels=[hidden_size] * num_layers,
        kernel_size=kernel_size,
        dropout=dropout_rate
    )

    # TODO: Implement continue_training logic with the new model_manager
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    emit_progress("Starting model training...")

    train_losses = []
    test_losses = []
    start_time = time.time()

    # Lock the version for this training run
    current_version = model_manager.peek_next_version(
        model_type='tcn',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 Version locked for this training run: {current_version}")

    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    best_model_path = None

    progress_manager.set_stage("model_training", 0)
    emit_progress(f"Training started - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}")

    for epoch in range(epochs):
        progress_manager.start_epoch(epoch)

        # --- Training pass ---
        model.train()
        epoch_loss = 0.0

        for batch_idx, (X_batch, y_batch) in enumerate(train_loader):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            # Give 2-D targets (batch, horizon) a trailing axis to match the 3-D model output
            if y_batch.dim() == 2:
                y_batch = y_batch.unsqueeze(-1)

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()

            if batch_idx % 10 == 0 or batch_idx == len(train_loader) - 1:
                current_lr = optimizer.param_groups[0]['lr']
                progress_manager.update_batch(batch_idx, loss.item(), current_lr)

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # --- Validation pass ---
        progress_manager.set_stage("validation", 0)

        model.eval()
        test_loss = 0.0
        with torch.no_grad():
            for batch_idx, (X_batch, y_batch) in enumerate(test_loader):
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                if y_batch.dim() == 2:
                    y_batch = y_batch.unsqueeze(-1)

                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

                if batch_idx % 5 == 0 or batch_idx == len(test_loader) - 1:
                    # Count the current batch as done so the last report reaches 100%
                    val_progress = ((batch_idx + 1) / len(test_loader)) * 100
                    progress_manager.set_stage("validation", val_progress)

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        progress_manager.finish_epoch(train_loss, test_loss)

        if (epoch + 1) % 5 == 0 or epoch == epochs - 1:
            progress = ((epoch + 1) / epochs) * 100
            current_metrics = {
                'train_loss': train_loss,
                'test_loss': test_loss,
                'epoch': epoch + 1,
                'total_epochs': epochs
            }
            emit_progress(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                          progress=progress, metrics=current_metrics)

        # The best model is only tracked on checkpoint epochs
        if (epoch + 1) % checkpoint_interval == 0 or epoch == epochs - 1:
            checkpoint_data = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'train_loss': train_loss,
                'test_loss': test_loss,
                'train_losses': train_losses,
                'test_losses': test_losses,
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': model_config,
                'training_info': {
                    'product_id': product_id,
                    'product_name': product_name,
                    'training_mode': training_mode,
                    'store_id': store_id,
                    'aggregation_method': aggregation_method,
                    'timestamp': time.time()
                }
            }

            if test_loss < best_loss:
                best_loss = test_loss
                best_model_path, _ = model_manager.save_model(
                    model_data=checkpoint_data,
                    product_id=product_id,
                    model_type='tcn',
                    store_id=store_id,
                    training_mode=training_mode,
                    aggregation_method=aggregation_method,
                    product_name=product_name,
                    version=f"{current_version}_best"
                )
                emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    training_time = time.time() - start_time

    progress_manager.set_stage("model_saving", 0)
    emit_progress("Training complete, saving model...")

    # Evaluate on the full test set with the final-epoch weights (not the best checkpoint)
    model.eval()
    with torch.no_grad():
        all_test_X = []
        all_test_Y = []
        for X_batch, y_batch in test_loader:
            all_test_X.append(X_batch)
            all_test_Y.append(y_batch)

        testX_tensor = torch.cat(all_test_X, dim=0)
        testY_tensor = torch.cat(all_test_Y, dim=0)

        test_pred = model(testX_tensor.to(DEVICE))
        test_pred = test_pred.squeeze(-1).cpu().numpy()          # (samples, horizon)
        test_true = testY_tensor.cpu().numpy().reshape(test_pred.shape)

    def inverse_scale(arr):
        # Inverse scaling needs a reshape when scaler_y was fit on the single
        # target column: flatten (samples, horizon) to one column, then restore.
        if getattr(scaler_y, 'n_features_in_', arr.shape[-1]) == 1:
            return scaler_y.inverse_transform(arr.reshape(-1, 1)).reshape(arr.shape)
        return scaler_y.inverse_transform(arr)

    test_pred_inv = inverse_scale(test_pred)
    test_true_inv = inverse_scale(test_true)

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f} s")

    final_model_data = {
        'epoch': epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'train_loss': train_losses[-1],
        'test_loss': test_losses[-1],
        'train_losses': train_losses,
        'test_losses': test_losses,
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': model_config,
        'metrics': metrics,
        'training_info': {
            'product_id': product_id,
            'product_name': product_name,
            'training_mode': training_mode,
            'store_id': store_id,
            'aggregation_method': aggregation_method,
            'timestamp': time.time(),
            'training_completed': True
        }
    }

    progress_manager.set_stage("model_saving", 50)

    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=product_id,
        model_type='tcn',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version
    )

    progress_manager.set_stage("model_saving", 100)

    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
        'mae': metrics['mae'],
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs,
        'version': final_version
    }

    emit_progress(f"Model training complete! Version {final_version} saved",
                  progress=100, metrics=final_metrics)

    # Prepare scope and identifier to build standardized file names
    scope = training_mode
    if scope == 'product':
        identifier = model_identifier
    elif scope == 'store':
        identifier = store_id
    elif scope == 'global':
        identifier = aggregation_method
    else:
        identifier = product_name  # fallback

    # Plot the loss curve, tagged with the version number
    loss_curve_path = plot_loss_curve(
        train_losses=train_losses,
        val_losses=test_losses,
        model_type='tcn',
        scope=scope,
        identifier=identifier,
        version=current_version,  # use the locked version
        model_dir=model_dir
    )
    print(f"📈 Versioned loss curve saved: {loss_curve_path}")

    # Record the loss-curve path in the model data. This runs after save_model(),
    # so it only updates the in-memory dict.
    final_model_data['loss_curve_path'] = loss_curve_path

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version
    }

    return metrics, artifacts


# --- Register this trainer with the system ---
from models.model_registry import register_trainer

register_trainer('tcn', train_product_model_with_tcn)
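

# ---------------------------------------------------------------------------
# Sketch (not part of the original trainer): inverse-scaling multi-step
# forecasts. The evaluation step above notes that inverse normalization needs
# a reshape. Assuming scaler_y was fit on the single target column, a
# (samples, forecast_horizon) array has to be flattened to one column before
# inverse_transform() and restored to its shape afterwards.
# The helper name `inverse_scale_horizon` is hypothetical. Any object with an
# sklearn-style inverse_transform(X) method that accepts a 2-D (n, 1) array
# should work.
# ---------------------------------------------------------------------------
import numpy as np  # re-imported so this sketch stands on its own


def inverse_scale_horizon(scaler, values):
    """Undo single-column scaling on an array of any shape, preserving the shape."""
    arr = np.asarray(values, dtype=float)
    # Flatten to one column, undo the scaling, then restore the original shape
    restored = scaler.inverse_transform(arr.reshape(-1, 1))
    return np.asarray(restored).reshape(arr.shape)


# Usage (hypothetical): with a MinMaxScaler fit on sales values in [0, 10],
# inverse_scale_horizon(scaler, [[0.0, 0.5]]) maps back to [[0.0, 5.0]].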