ShopTRAINING/server/trainers/mlstm_trainer.py
2025-07-02 11:05:23 +08:00

613 lines
24 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - mLSTM模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from utils.data_utils import create_dataset, PharmacyDataset
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import (
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON,
get_next_model_version, get_model_file_path, get_latest_model_version
)
from utils.training_progress import progress_manager
def _checkpoint_filename(product_id, model_type, epoch_or_label, store_id=None,
                         training_mode='product', aggregation_method=None):
    """Build the checkpoint file name used by both save_checkpoint and load_checkpoint.

    Keeping the naming scheme in one place guarantees that a file written by
    save_checkpoint is found again by load_checkpoint for the same arguments.
    """
    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{product_id}_epoch_{epoch_or_label}.pth"
    if training_mode == 'global' and aggregation_method:
        return f"{model_type}_global_{product_id}_{aggregation_method}_epoch_{epoch_or_label}.pth"
    # 'product' mode, and store/global runs missing their store_id / aggregation_method.
    return f"{model_type}_product_{product_id}_epoch_{epoch_or_label}.pth"


def _torch_load_checkpoint(path):
    """Load a checkpoint written by save_checkpoint onto DEVICE.

    Checkpoints produced by this trainer pickle sklearn MinMaxScaler objects
    next to the weights. torch >= 2.6 defaults to ``weights_only=True`` and
    rejects such objects, so full unpickling is requested explicitly.
    Unpickling can execute arbitrary code: only use on trusted, locally
    produced files.
    """
    try:
        return torch.load(path, map_location=DEVICE, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        return torch.load(path, map_location=DEVICE)


def save_checkpoint(checkpoint_data: dict, epoch_or_label, product_id: str,
                    model_type: str, model_dir: str, store_id=None,
                    training_mode: str = 'product', aggregation_method=None) -> str:
    """Save a training checkpoint under ``<model_dir>/checkpoints``.

    Args:
        checkpoint_data: Object passed to ``torch.save``.
        epoch_or_label: Epoch number or a label such as ``'best'``; embedded in the file name.
        product_id: Product ID.
        model_type: Model type prefix for the file name (e.g. ``'mlstm'``).
        model_dir: Root model directory; the ``checkpoints`` subdirectory is created if missing.
        store_id: Store ID, used in the name for ``'store'`` mode.
        training_mode: ``'product'``, ``'store'`` or ``'global'``.
        aggregation_method: Aggregation method, used in the name for ``'global'`` mode.

    Returns:
        Path of the written ``.pth`` file.
    """
    checkpoint_dir = os.path.join(model_dir, 'checkpoints')
    os.makedirs(checkpoint_dir, exist_ok=True)
    checkpoint_path = os.path.join(
        checkpoint_dir,
        _checkpoint_filename(product_id, model_type, epoch_or_label,
                             store_id, training_mode, aggregation_method),
    )
    torch.save(checkpoint_data, checkpoint_path)
    print(f"[mLSTM] 检查点已保存: {checkpoint_path}", flush=True)
    return checkpoint_path


def load_checkpoint(product_id: str, model_type: str, epoch_or_label,
                    model_dir: str, store_id=None, training_mode: str = 'product',
                    aggregation_method=None):
    """Load a checkpoint previously written by save_checkpoint.

    Args:
        product_id: Product ID.
        model_type: Model type prefix used in the file name.
        epoch_or_label: Epoch number or label (e.g. ``'best'``).
        model_dir: Root model directory containing ``checkpoints``.
        store_id: Store ID (``'store'`` mode).
        training_mode: ``'product'``, ``'store'`` or ``'global'``.
        aggregation_method: Aggregation method (``'global'`` mode).

    Returns:
        The loaded checkpoint object, or None if the file is missing or fails to load.
    """
    checkpoint_path = os.path.join(
        model_dir,
        'checkpoints',
        _checkpoint_filename(product_id, model_type, epoch_or_label,
                             store_id, training_mode, aggregation_method),
    )
    if not os.path.exists(checkpoint_path):
        print(f"[mLSTM] 检查点文件不存在: {checkpoint_path}", flush=True)
        return None
    try:
        checkpoint_data = _torch_load_checkpoint(checkpoint_path)
    except Exception as e:
        # Best-effort: callers treat None as "no usable checkpoint".
        print(f"[mLSTM] 加载检查点失败: {e}", flush=True)
        return None
    print(f"[mLSTM] 检查点已加载: {checkpoint_path}", flush=True)
    return checkpoint_data
def _load_mlstm_training_data(product_id, store_id, training_mode, aggregation_method):
    """Load the sales rows used to train one product's mLSTM model.

    Returns:
        ``(product_df, training_scope)``: the DataFrame and a (Chinese) label
        describing which data slice was used.

    If loading from ``pharmacy_sales_multi_store.csv`` raises, falls back to
    ``pharmacy_sales.xlsx`` filtered by ``product_id`` and sorted by ``date``.
    """
    from utils.multi_store_data_utils import load_multi_store_data

    try:
        if training_mode == 'store' and store_id:
            # One store, one product.
            product_df = get_store_product_sales_data(
                store_id,
                product_id,
                'pharmacy_sales_multi_store.csv'
            )
            training_scope = f"店铺 {store_id}"
        elif training_mode == 'global':
            # Product aggregated across stores with the requested method.
            product_df = aggregate_multi_store_data(
                product_id,
                aggregation_method=aggregation_method,
                file_path='pharmacy_sales_multi_store.csv'
            )
            training_scope = f"全局聚合({aggregation_method})"
        else:
            # Default: whatever load_multi_store_data returns for this product.
            product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
            training_scope = "所有店铺"
    except Exception as e:
        print(f"多店铺数据加载失败: {e}")
        # Fallback: the original single-file dataset.
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values(by='date')
        training_scope = "原始数据"
    return product_df, training_scope


def _load_trusted_model_file(path):
    """torch.load a locally produced model file onto DEVICE with full unpickling.

    Files written by this trainer pickle sklearn MinMaxScaler objects next to
    the weights, which torch >= 2.6 rejects under its ``weights_only=True``
    default. Unpickling can execute arbitrary code: trusted files only.
    """
    try:
        return torch.load(path, map_location=DEVICE, weights_only=False)
    except TypeError:
        # Older torch releases do not accept the ``weights_only`` keyword.
        return torch.load(path, map_location=DEVICE)


def _align_target_shape(outputs, target):
    """Reshape ``target`` to ``outputs.shape`` when both hold the same number of elements.

    nn.MSELoss broadcasts mismatched shapes (emitting only a warning), which
    silently computes the wrong loss, e.g. (B, H) predictions against a
    (B, H, 1) target. Always unsqueezing 2-D targets is only correct when the
    model emits 3-D output.
    """
    if target.shape != outputs.shape and target.numel() == outputs.numel():
        return target.reshape(outputs.shape)
    return target


def _run_mlstm_epoch(model, loader, criterion, optimizer=None):
    """Run one pass over ``loader`` and return the mean per-batch loss.

    With an ``optimizer`` the model is in train mode and updated after every
    batch; without one it runs in eval mode with gradients disabled.
    The caller must ensure ``loader`` yields at least one batch.
    """
    training = optimizer is not None
    model.train(training)
    total_loss = 0.0
    with torch.set_grad_enabled(training):
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            outputs = model(X_batch)
            loss = criterion(outputs, _align_target_shape(outputs, y_batch))
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item()
    return total_loss / len(loader)


def _save_mlstm_loss_curve(train_losses, test_losses, model_dir, store_id,
                           product_id, product_name, version, training_scope):
    """Plot train/test loss per epoch, save it as a PNG and return the file path.

    Runs with a truthy ``store_id`` are written to ``<model_dir>/mlstm/<store_id>/``;
    all others to ``<model_dir>/mlstm/global/``.
    """
    if store_id:
        target_dir = os.path.join(model_dir, 'mlstm', store_id)
        loss_curve_filename = f"{product_id}_mlstm_{version}_loss_curve.png"
    else:
        target_dir = os.path.join(model_dir, 'mlstm', 'global')
        loss_curve_filename = f"{product_id}_mlstm_{version}_global_loss_curve.png"
    os.makedirs(target_dir, exist_ok=True)
    loss_curve_path = os.path.join(target_dir, loss_curve_filename)

    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(train_losses, label='Training Loss')
        plt.plot(test_losses, label='Test Loss')
        title_suffix = f" - {training_scope}" if store_id else " - 全局模型"
        plt.title(f'mLSTM 模型训练损失曲线 - {product_name} ({version}){title_suffix}')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        plt.savefig(loss_curve_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if saving fails.
        plt.close(fig)
    print(f"损失曲线已保存到: {loss_curve_path}")
    return loss_curve_path


def train_product_model_with_mlstm(
    product_id,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    model_dir=DEFAULT_MODEL_DIR,
    version=None,
    socketio=None,
    task_id=None,
    continue_training=False,
    progress_callback=None
):
    """Train an mLSTM (MLSTMTransformer) sales-forecasting model for one product.

    Args:
        product_id: Product ID.
        store_id: Store ID. Selects data only when ``training_mode == 'store'``;
            a truthy value also places the loss curve in a per-store directory.
        training_mode: ``'product'``, ``'store'`` or ``'global'``.
        aggregation_method: Passed to aggregate_multi_store_data in ``'global'`` mode
            (``'sum'``, ``'mean'``, ``'weighted'``).
        epochs: Number of training epochs; must be >= 1.
        model_dir: Root directory for checkpoints and the loss curve.
        version: Model version; resolved automatically when None.
        socketio: Optional Socket.IO server used for live progress events.
        task_id: Task ID attached to every progress event.
        continue_training: Start from the existing weights of ``version`` if found.
        progress_callback: Optional callable receiving each progress dict
            (used for multi-process training).

    Progress is reported as 10-90 across epochs, then 95/98/99/100 for the
    loss curve, evaluation and final save.

    Returns:
        ``(model, metrics, epochs, final_model_path)``.

    Raises:
        ValueError: if ``epochs`` < 1, or there is too little data to build
            both training and test windows.
    """
    import sys

    def emit_progress(message, progress=None, metrics=None):
        """Send a progress event to the callback and/or Socket.IO, then log it."""
        progress_data = {
            'task_id': task_id,
            'message': message,
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics
        # Multi-process training reports through the callback.
        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                print(f"[mLSTM] 进度回调失败: {e}")
        # Single-process training pushes directly over Socket.IO.
        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[mLSTM] WebSocket发送失败: {e}")
        print(f"[mLSTM] {message}", flush=True)
        # Flush both streams; presumably so a parent process sees output promptly.
        sys.stdout.flush()
        sys.stderr.flush()

    # Fail fast: with no epochs the final save would crash on an empty loss history.
    if epochs < 1:
        raise ValueError(f"epochs 必须 >= 1, 当前值: {epochs}")

    emit_progress("开始mLSTM模型训练...")

    # Continuing reuses the latest version when one exists; otherwise allocate a new one.
    if version is None:
        if continue_training:
            version = get_latest_model_version(product_id, 'mlstm')
        if version is None:
            version = get_next_model_version(product_id, 'mlstm')
    emit_progress(f"开始训练 mLSTM 模型版本 {version}")

    # Start the shared progress manager for this task unless it is already tracking it.
    if socketio and task_id:
        print(f"[mLSTM] 任务 {task_id}: 开始mLSTM训练器", flush=True)
        try:
            if not hasattr(progress_manager, 'training_id') or progress_manager.training_id != task_id:
                progress_manager.start_training(
                    training_id=task_id,
                    product_id=product_id,
                    model_type='mlstm',
                    training_mode=training_mode,
                    total_epochs=epochs,
                    total_batches=0,  # placeholder; not updated later in this function
                    batch_size=32,
                    total_samples=0  # placeholder; not updated later in this function
                )
                print(f"[mLSTM] 任务 {task_id}: 进度管理器已初始化", flush=True)
            else:
                print(f"[mLSTM] 任务 {task_id}: 使用现有进度管理器", flush=True)
        except Exception as e:
            print(f"[mLSTM] 任务 {task_id}: 进度管理器初始化失败: {e}", flush=True)

    product_df, training_scope = _load_mlstm_training_data(
        product_id, store_id, training_mode, aggregation_method
    )

    # The series must cover at least one input window plus one forecast horizon.
    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={LOOK_BACK} + FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"实际数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"
            f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
            f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
            f"3. 使用全局训练模式聚合更多数据"
        )
        print(error_msg)
        emit_progress(f"训练失败:数据不足 ({len(product_df)}/{min_required_samples} 天)")
        raise ValueError(error_msg)

    product_name = product_df['product_name'].iloc[0]
    print(f"[mLSTM] 使用mLSTM模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
    print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
    print(f"[mLSTM] 版本: {version}", flush=True)
    print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
    print(f"[mLSTM] 模型将保存到目录: {model_dir}", flush=True)
    print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
    emit_progress(f"训练产品: {product_name} (ID: {product_id}) - {training_scope}")

    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    print(f"[mLSTM] 开始数据预处理,特征: {features}", flush=True)
    X = product_df[features].values
    y = product_df[['sales']].values  # double brackets keep y 2-D, as MinMaxScaler expects
    print(f"[mLSTM] 特征矩阵形状: {X.shape}, 目标矩阵形状: {y.shape}", flush=True)
    emit_progress("数据预处理中...")

    # NOTE(review): the scalers are fit on the full series, including the held-out
    # 20%, so test metrics see test-set min/max. Fitting on the training split only
    # would avoid the leakage but changes the saved scalers; confirm before changing.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)
    print(f"[mLSTM] 数据归一化完成", flush=True)

    # Chronological 80/20 split (no shuffling across the boundary).
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)

    # The whole-series check above does not guarantee that each split yields a
    # window; an empty loader would otherwise crash with ZeroDivisionError later.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"按 80/20 划分后无法生成完整的时间窗口 (LOOK_BACK={LOOK_BACK}, FORECAST_HORIZON={FORECAST_HORIZON})\n"
            f"训练窗口: {len(trainX)}, 测试窗口: {len(testX)}, 总数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}"
        )
        print(error_msg)
        emit_progress(f"训练失败:数据不足 (训练窗口 {len(trainX)}, 测试窗口 {len(testX)})")
        raise ValueError(error_msg)

    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Batch statistics are only reported here; progress_manager keeps its placeholders.
    total_batches = len(train_loader)
    total_samples = len(train_dataset)
    print(f"[mLSTM] 数据加载器创建完成 - 批次数: {total_batches}, 样本数: {total_samples}", flush=True)
    emit_progress(f"数据加载器准备完成 - 批次数: {total_batches}, 样本数: {total_samples}")

    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    hidden_size = 128
    num_heads = 4
    dropout_rate = 0.1
    num_blocks = 3
    embed_dim = 32
    dense_dim = 32

    print(f"[mLSTM] 初始化模型 - 输入维度: {input_dim}, 输出维度: {output_dim}", flush=True)
    print(f"[mLSTM] 模型参数 - 隐藏层: {hidden_size}, 注意力头: {num_heads}", flush=True)
    emit_progress(f"初始化mLSTM模型 - 输入维度: {input_dim}, 隐藏层: {hidden_size}")
    model = MatrixLSTM(
        num_features=input_dim,
        hidden_size=hidden_size,
        mlstm_layers=2,
        embed_dim=embed_dim,
        dense_dim=dense_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        num_blocks=num_blocks,
        output_sequence_length=output_dim
    )
    print(f"[mLSTM] 模型创建完成", flush=True)
    emit_progress("mLSTM模型初始化完成")

    # Continue training from the existing weights when asked; fall back to a fresh start.
    if continue_training and version != 'v1':
        try:
            existing_model_path = get_model_file_path(product_id, 'mlstm', version)
            if os.path.exists(existing_model_path):
                checkpoint = _load_trusted_model_file(existing_model_path)
                model.load_state_dict(checkpoint['model_state_dict'])
                print(f"加载现有模型: {existing_model_path}")
                emit_progress(f"加载现有模型版本 {version} 进行继续训练")
        except Exception as e:
            print(f"无法加载现有模型,将重新开始训练: {e}")
            emit_progress("无法加载现有模型,重新开始训练")

    model = model.to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    emit_progress("数据预处理完成,开始模型训练...", progress=10)

    train_losses = []
    test_losses = []
    model_config = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'dropout': dropout_rate,
        'num_blocks': num_blocks,
        'embed_dim': embed_dim,
        'dense_dim': dense_dim,
        'sequence_length': LOOK_BACK,
        'forecast_horizon': FORECAST_HORIZON,
        'model_type': 'mlstm'
    }

    def build_checkpoint(epoch_number, train_loss, test_loss):
        """Assemble the checkpoint payload for the current model/optimizer state."""
        return {
            'epoch': epoch_number,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_loss,
            'test_loss': test_loss,
            'train_losses': train_losses,
            'test_losses': test_losses,
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
            'config': dict(model_config),
            'training_info': {
                'product_id': product_id,
                'product_name': product_name,
                'training_mode': training_mode,
                'store_id': store_id,
                'aggregation_method': aggregation_method,
                'training_scope': training_scope,
                'timestamp': time.time()
            }
        }

    start_time = time.time()
    # Periodic checkpoint roughly every 10% of the run, at least every epoch.
    checkpoint_interval = max(1, epochs // 10)
    best_loss = float('inf')
    emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}")

    for epoch in range(epochs):
        epoch_number = epoch + 1
        emit_progress(f"开始训练 Epoch {epoch+1}/{epochs}")

        train_loss = _run_mlstm_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(train_loss)
        test_loss = _run_mlstm_epoch(model, test_loader, criterion)
        test_losses.append(test_loss)

        # Epochs span 10-90 so the post-training steps (95/98/99/100) never go backwards.
        epoch_progress = 10 + (epoch_number / epochs) * 80
        current_metrics = {
            'train_loss': train_loss,
            'test_loss': test_loss,
            'epoch': epoch_number,
            'total_epochs': epochs,
            'learning_rate': optimizer.param_groups[0]['lr']
        }
        emit_progress(f"Epoch {epoch+1}/{epochs} 完成 - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                      progress=epoch_progress, metrics=current_metrics)

        is_periodic = epoch_number % checkpoint_interval == 0 or epoch_number == epochs
        # The best model is tracked every epoch, not only on periodic checkpoint epochs.
        is_best = test_loss < best_loss
        if is_periodic or is_best:
            checkpoint_data = build_checkpoint(epoch_number, train_loss, test_loss)
            if is_periodic:
                save_checkpoint(checkpoint_data, epoch_number, product_id, 'mlstm',
                                model_dir, store_id, training_mode, aggregation_method)
            if is_best:
                best_loss = test_loss
                save_checkpoint(checkpoint_data, 'best', product_id, 'mlstm',
                                model_dir, store_id, training_mode, aggregation_method)
                emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
            if is_periodic:
                emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")

        if epoch_number % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)

    training_time = time.time() - start_time

    emit_progress("生成损失曲线...", progress=95)
    loss_curve_path = _save_mlstm_loss_curve(
        train_losses, test_losses, model_dir, store_id,
        product_id, product_name, version, training_scope
    )

    emit_progress("模型评估中...", progress=98)
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
    if test_pred.ndim == 3:
        test_pred = test_pred.squeeze(-1)
    # Undo the target scaling so metrics are in original sales units.
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time
    metrics['version'] = version

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}")

    emit_progress("保存最终模型...", progress=99)
    # The final model is the last-epoch state, plus evaluation results.
    final_model_data = build_checkpoint(epochs, train_losses[-1], test_losses[-1])
    final_model_data['metrics'] = metrics
    final_model_data['loss_curve_path'] = loss_curve_path
    final_model_data['training_info']['training_completed'] = True
    final_model_path = save_checkpoint(
        final_model_data, f"final_epoch_{epochs}", product_id, 'mlstm',
        model_dir, store_id, training_mode, aggregation_method
    )

    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
        'mae': metrics['mae'],
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs,
        'model_path': final_model_path
    }
    emit_progress(f"✅ mLSTM模型训练完成最终epoch: {epochs} 已保存", progress=100, metrics=final_metrics)
    return model, metrics, epochs, final_model_path