"""
Pharmacy sales forecasting system - mLSTM model training function
"""

import os
import sys
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from datetime import datetime

from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from utils.data_utils import create_dataset, PharmacyDataset
from analysis.metrics import evaluate_model
from core.config import (
    DEVICE, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager


def train_product_model_with_mlstm(
    product_id,
    product_df,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    socketio=None,
    task_id=None,
    continue_training=False,
    progress_callback=None,
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
):
    """
    Train a product sales forecasting model with mLSTM.

    Parameters:
        product_id: product ID
        product_df: DataFrame of daily records; must contain 'product_name'
            and the feature columns listed below
        store_id: store ID; None means global data is used
        training_mode: training mode ('product', 'store', 'global')
        aggregation_method: aggregation method ('sum', 'mean', 'weighted')
        epochs: number of training epochs
        socketio: Socket.IO instance for real-time progress push
        task_id: task ID
        continue_training: whether to continue training (currently ignored;
            training always starts from scratch)
        progress_callback: progress callback, used for multi-process training
        patience: epochs without test-loss improvement before early stopping
        learning_rate: initial Adam learning rate
        clip_norm: max gradient norm for clipping; a falsy value disables it

    The output directory and version number are resolved by model_manager.

    Returns:
        (model, metrics, version, model_version_path)
    """

    # Progress reporter over WebSocket; progress_callback also covers multi-process training
    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the front end."""
        progress_data = {
            'task_id': task_id,
            'message': message,
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics

        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                print(f"[mLSTM] Progress callback failed: {e}")

        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[mLSTM] WebSocket emit failed: {e}")

        print(f"[mLSTM] {message}", flush=True)
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress("Starting mLSTM model training...")

    # 1. Determine the model identifier and version
    model_type = 'mlstm'
    if training_mode == 'store':
        scope = f"{store_id}_{product_id}"
    elif training_mode == 'global':
        scope = f"{product_id}" if product_id else "all"
    else:
        scope = f"{product_id}_all"

    model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
    version = model_manager.get_next_version_number(model_identifier)

    emit_progress(f"Training mLSTM model v{version}")

    # 2. Resolve the model version path
    model_version_path = model_manager.get_model_version_path(model_type, training_mode, scope, version, aggregation_method)
    emit_progress(f"Model will be saved to: {model_version_path}")

    if training_mode == 'store' and store_id:
        training_scope = f"store {store_id}"
    elif training_mode == 'global':
        training_scope = f"global aggregation ({aggregation_method})"
    else:
        training_scope = "all stores"

    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    if len(product_df) < min_required_samples:
        error_msg = f"Insufficient data: need {min_required_samples} days, got {len(product_df)} days."
        emit_progress(f"Training failed: {error_msg}")
        raise ValueError(error_msg)

    product_name = product_df['product_name'].iloc[0]
    print(f"[mLSTM] Training mLSTM sales forecasting model for product '{product_name}' (ID: {product_id})", flush=True)
    print(f"[mLSTM] Training scope: {training_scope}", flush=True)
    print(f"[mLSTM] Version: v{version}", flush=True)
    print(f"[mLSTM] Device: {DEVICE}", flush=True)
    print(f"[mLSTM] Data size: {len(product_df)} records", flush=True)

    emit_progress(f"Training product: {product_name} (ID: {product_id}) - {training_scope}")

    # Build the feature and target arrays
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    print(f"[mLSTM] Starting preprocessing, features: {features}", flush=True)

    X = product_df[features].values
    y = product_df[['sales']].values

    print(f"[mLSTM] Feature matrix shape: {X.shape}, target matrix shape: {y.shape}", flush=True)
    emit_progress("Preprocessing data...")

    # Chronological 80/20 split. The scalers are fitted on the training rows only,
    # so min/max statistics from the test period do not leak into training.
    train_size = int(len(X) * 0.8)
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    scaler_X.fit(X[:train_size])
    scaler_y.fit(y[:train_size])
    X_scaled = scaler_X.transform(X)
    y_scaled = scaler_y.transform(y)

    print("[mLSTM] Normalization complete", flush=True)

    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)

    # Each split must yield at least one window; an empty test set would make
    # len(test_loader) == 0 and the per-epoch loss average divide by zero.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (f"Insufficient data after the 80/20 split (train: {train_size} days, "
                     f"test: {len(X) - train_size} days) to form training and test windows.")
        emit_progress(f"Training failed: {error_msg}")
        raise ValueError(error_msg)

    train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=32, shuffle=True)
    test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=32, shuffle=False)

    total_batches = len(train_loader)
    total_samples = len(trainX)

    print(f"[mLSTM] Data loaders ready - batches: {total_batches}, samples: {total_samples}", flush=True)
    emit_progress(f"Data loaders ready - batches: {total_batches}, samples: {total_samples}")

    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON

    hidden_size, num_heads, dropout_rate, num_blocks, embed_dim, dense_dim = 128, 4, 0.1, 3, 32, 32

    model = MatrixLSTM(
        num_features=input_dim,
        hidden_size=hidden_size,
        mlstm_layers=2,
        embed_dim=embed_dim,
        dense_dim=dense_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        num_blocks=num_blocks,
        output_sequence_length=output_dim
    ).to(DEVICE)

    print("[mLSTM] Model created", flush=True)
    emit_progress("mLSTM model initialized")

    if continue_training:
        emit_progress("Continue-training mode requested, but this refactored version trains from scratch.")

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    emit_progress("Preprocessing complete, starting model training...", progress=10)

    train_losses, test_losses = [], []
    start_time = time.time()

    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    epochs_no_improve = 0

    emit_progress(f"Training started - total epochs: {epochs}, checkpoint interval: {checkpoint_interval}, patience: {patience}")

    for epoch in range(epochs):
        emit_progress(f"Starting epoch {epoch+1}/{epochs}")

        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            if clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()
            epoch_loss += loss.item()

        # Average training loss over batches
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()
        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        # Update the learning rate
        scheduler.step(test_loss)

        emit_progress(f"Epoch {epoch+1}/{epochs} done - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                      progress=10 + ((epoch + 1) / epochs) * 85)

        # 3. Save checkpoints: periodically, and whenever the test loss improves
        checkpoint_data = {
            'epoch': epoch + 1,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
        }

        if (epoch + 1) % checkpoint_interval == 0:
            model_manager.save_model_artifact(checkpoint_data, f"checkpoint_epoch_{epoch+1}.pth", model_version_path)
            emit_progress(f"💾 Saved training checkpoint epoch_{epoch+1}")

        if test_loss < best_loss:
            best_loss = test_loss
            model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
            emit_progress(f"💾 Saved best-model checkpoint (epoch {epoch+1}, test_loss: {test_loss:.4f})")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            emit_progress(f"Test loss has not improved for {patience} consecutive epochs; stopping early.")
            break

    training_time = time.time() - start_time

    loss_fig = plt.figure(figsize=(10, 6))
    plt.plot(train_losses, label='Training Loss')
    plt.plot(test_losses, label='Test Loss')
    plt.title(f'mLSTM Loss Curve - {product_name} (v{version}) - {training_scope}')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
    plt.close(loss_fig)
    print(f"Loss curve saved to: {os.path.join(model_version_path, 'loss_curve.png')}")

    model.eval()
    with torch.no_grad():
        test_pred = model(torch.Tensor(testX).to(DEVICE)).cpu().numpy()

    metrics = evaluate_model(scaler_y.inverse_transform(testY), scaler_y.inverse_transform(test_pred))
    metrics['training_time'] = training_time

    # Print evaluation metrics
    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f} s")

    # 4. Save the final model (weights from the last completed epoch);
    # len(train_losses) also stays defined when epochs == 0
    final_model_data = {
        'epoch': len(train_losses),
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
    }
    model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)

    # 5. Save metadata
    metadata = {
        'product_id': product_id,
        'product_name': product_name,
        'model_type': model_type,
        'version': f'v{version}',
        'training_mode': training_mode,
        'scope': scope,
        'aggregation_method': aggregation_method,
        'training_scope_description': training_scope,
        'timestamp': datetime.now().isoformat(),
        'metrics': metrics,
        'config': {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'num_heads': num_heads,
            'dropout': dropout_rate,
            'num_blocks': num_blocks,
            'embed_dim': embed_dim,
            'dense_dim': dense_dim,
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
        }
    }
    model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)

    # 6. Update the version file
    model_manager.update_version(model_identifier, version)

    emit_progress(f"✅ mLSTM model v{version} training complete!", progress=100, metrics=metrics)

    return model, metrics, version, model_version_path
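

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the training pipeline: one common way to
# build a sliding-window dataset like the one create_dataset() returns. The
# real utils.data_utils.create_dataset may differ; this hypothetical helper
# ASSUMES each sample takes `look_back` consecutive feature rows as input and
# the next `horizon` target values as output. Under that assumption a series
# of n rows yields n - look_back - horizon + 1 windows, which is why fewer
# than LOOK_BACK + FORECAST_HORIZON rows cannot form even one sample.
# ---------------------------------------------------------------------------
import numpy as np


def _sketch_sliding_windows(X, y, look_back, horizon):
    """Hypothetical helper: build (inputs, targets) windows from aligned arrays.

    X has shape (n, num_features) and y has shape (n, 1). Returns inputs of
    shape (m, look_back, num_features) and targets of shape (m, horizon),
    where m = max(0, n - look_back - horizon + 1).
    """
    X = np.asarray(X, dtype=np.float32)
    y = np.asarray(y, dtype=np.float32).reshape(-1)
    num_windows = max(0, len(X) - look_back - horizon + 1)
    inputs = np.empty((num_windows, look_back, X.shape[1]), dtype=np.float32)
    targets = np.empty((num_windows, horizon), dtype=np.float32)
    for i in range(num_windows):
        # Input window: rows i .. i + look_back - 1
        inputs[i] = X[i:i + look_back]
        # Target window: the `horizon` values immediately after the input window
        targets[i] = y[i + look_back:i + look_back + horizon]
    return inputs, targets

# Usage (hypothetical): 10 rows with look_back=3 and horizon=2 give 6 windows,
# while 4 rows give none.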
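

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the training pipeline: why min-max
# statistics for a chronological split should come from the training rows
# only. This hypothetical helper reimplements the MinMaxScaler(feature_range=
# (0, 1)) arithmetic in NumPy, x_scaled = (x - min_train) / (max_train -
# min_train), treating a constant column's range as 1 as scikit-learn does.
# Test values outside the training range then map outside [0, 1] instead of
# silently shaping the scale the model is trained on.
# ---------------------------------------------------------------------------
import numpy as np


def _sketch_chronological_minmax(X, train_ratio=0.8):
    """Hypothetical helper: split rows in time order, scale with train-only min/max."""
    X = np.asarray(X, dtype=np.float64)
    train_size = int(len(X) * train_ratio)
    data_min = X[:train_size].min(axis=0)
    data_range = X[:train_size].max(axis=0) - data_min
    # Avoid division by zero for constant columns, mirroring MinMaxScaler
    data_range[data_range == 0] = 1.0
    X_scaled = (X - data_min) / data_range
    return X_scaled[:train_size], X_scaled[train_size:]

# Usage (hypothetical): with rows [0, 10, 5, 20] and train_ratio=0.5 the
# training rows scale to [0, 1] and the test value 20 scales to 2.0.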
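

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the training pipeline: the patience-based
# early stopping from the epoch loop, factored into a hypothetical helper
# class. A strictly lower test loss resets the counter (this is where the loop
# saves checkpoint_best.pth); otherwise the counter grows, and training stops
# once it reaches `patience`.
# ---------------------------------------------------------------------------
class _SketchEarlyStopping:
    """Hypothetical helper mirroring best_loss / epochs_no_improve bookkeeping."""

    def __init__(self, patience=10):
        self.patience = patience
        self.best_loss = float('inf')
        self.epochs_no_improve = 0

    def step(self, loss):
        """Record one epoch's test loss; return (improved, should_stop)."""
        improved = loss < self.best_loss
        if improved:
            self.best_loss = loss
            self.epochs_no_improve = 0
        else:
            self.epochs_no_improve += 1
        return improved, self.epochs_no_improve >= self.patience

# Usage (hypothetical): with patience=2, losses 1.0, 0.5, 0.6, 0.5 stop
# training after the fourth epoch, because an equal loss is not an improvement.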