ShopTRAINING/server/trainers/mlstm_trainer.py

495 lines
18 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - mLSTM模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from utils.data_utils import prepare_data, PharmacyDataset, prepare_sequences
2025-07-02 11:05:23 +08:00
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
2025-07-02 11:05:23 +08:00
from core.config import (
DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
2025-07-02 11:05:23 +08:00
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
2025-07-02 11:05:23 +08:00
def _train_one_epoch(model, loader, criterion, optimizer, clip_norm):
    """Run one optimisation pass over ``loader`` and return the mean per-batch loss.

    When ``clip_norm`` is truthy, gradients are clipped to a maximum total norm
    of ``clip_norm`` before each optimizer step. The caller must guarantee
    that ``loader`` yields at least one batch.
    """
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
        # NOTE(review): no reshape is applied; this assumes the model output
        # and y_batch already have matching shapes for the loss — confirm.
        loss = criterion(model(X_batch), y_batch)
        optimizer.zero_grad()
        loss.backward()
        if clip_norm:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _average_loss(model, loader, criterion):
    """Return the mean per-batch ``criterion`` loss of ``model`` over ``loader`` (eval mode, no grad)."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            total_loss += criterion(model(X_batch), y_batch).item()
    return total_loss / len(loader)


def _predict_batches(model, loader):
    """Run ``model`` over ``loader`` batch by batch; return ``(predictions, targets)`` as NumPy arrays.

    Predicting per batch, instead of concatenating the whole split into one
    tensor for a single forward pass, keeps peak device memory bounded by the
    loader's batch size. The caller must guarantee at least one batch.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            predictions.append(model(X_batch.to(DEVICE)).cpu())
            targets.append(y_batch.cpu())
    return torch.cat(predictions, dim=0).numpy(), torch.cat(targets, dim=0).numpy()


def train_product_model_with_mlstm(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int = 50,
    sequence_length: int = LOOK_BACK,
    forecast_horizon: int = FORECAST_HORIZON,
    model_dir: str = DEFAULT_MODEL_DIR,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    socketio=None,
    task_id: str = None,
    progress_callback=None,
    patience: int = 10,
    learning_rate: float = 0.001,
    clip_norm: float = 1.0,
    continue_training: bool = False,
    **kwargs  # Captures any other (unused) keyword arguments.
):
    """Train an mLSTM + Transformer sales-forecasting model (new data-pipeline version).

    The DataFrame is turned into train/test sequences by ``prepare_data`` and
    ``prepare_sequences``. The model is trained with Adam, MSE loss, optional
    gradient clipping, ``ReduceLROnPlateau`` on the test loss, and early
    stopping. Whenever the test loss improves, the current weights are saved
    as version ``<locked version>_best``. After training, the final-epoch
    model is saved under the locked version together with its metrics and
    the loss-curve path.

    Args:
        model_identifier: Unique identifier of the model subject (e.g.
            product_id or store_id). Used as the display-name fallback and,
            in 'product' mode, as the loss-curve identifier.
        training_df: Already-filtered training data.
        feature_list: Feature column names passed to ``prepare_data``.
        training_mode: Training mode ('product', 'store' or 'global').
        epochs: Maximum number of epochs; must be >= 1. Early stopping may end
            training sooner.
        sequence_length: Input window length passed to ``prepare_data``.
        forecast_horizon: Number of predicted steps (also the model's output size).
        model_dir: Directory passed to ``plot_loss_curve``.
        product_id: Forwarded to ``model_manager`` and progress reporting.
        store_id: Forwarded to ``model_manager``.
        aggregation_method: Forwarded to ``model_manager``. In 'global' mode
            it is also the loss-curve identifier.
        version: Only echoed in a progress message. The saved version always
            comes from ``model_manager.peek_next_version``.
        socketio: Socket.IO server. When it and ``task_id`` are both set,
            progress is emitted as 'training_progress' in the '/training'
            namespace.
        task_id: Task id attached to every progress message.
        progress_callback: Called with every progress dict (multi-process mode).
        patience: Stop after this many consecutive epochs without a test-loss
            improvement. The LR scheduler uses ``patience // 2``.
        learning_rate: Adam learning rate.
        clip_norm: Maximum gradient norm; a falsy value disables clipping.
        continue_training: Not implemented yet (currently has no effect).
        **kwargs: Ignored.

    Returns:
        ``(metrics, artifacts)``.
        ``metrics`` is the ``evaluate_model`` result on the de-normalised test
        split, plus ``'training_time'`` (seconds).
        ``artifacts`` holds ``'versioned_model'``, ``'loss_curve_plot'``,
        ``'best_model'`` (None if no epoch ever improved, e.g. NaN losses) and
        ``'version'``.

    Raises:
        ValueError: If ``epochs < 1``, if ``training_df`` has fewer than
            ``sequence_length + forecast_horizon`` rows, or if the train or
            test split yields no batches.
    """
    import sys

    training_scope = training_mode

    def emit_progress(message, progress=None, metrics=None):
        """Report training progress to the callback, the WebSocket and stdout.

        Each channel is best-effort: a failing callback or socket emit is
        printed and otherwise ignored, so reporting can never abort training.
        """
        progress_data = {
            'task_id': task_id,
            'message': message,
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics
        # Multi-process mode: progress flows through progress_callback.
        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                print(f"[mLSTM] 进度回调失败: {e}")
        # Single-process mode: emit directly over Socket.IO.
        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[mLSTM] WebSocket发送失败: {e}")
        print(f"[mLSTM] {message}", flush=True)
        # Force-flush so output from worker processes shows up promptly.
        sys.stdout.flush()
        sys.stderr.flush()

    emit_progress("开始mLSTM模型训练...")
    emit_progress("开始训练 mLSTM 模型")
    if version:
        emit_progress(f"使用指定版本: {version}")

    # Initialise the shared progress manager unless it is already tracking this task.
    if socketio and task_id:
        print(f"[mLSTM] 任务 {task_id}: 开始mLSTM训练器", flush=True)
        try:
            if not hasattr(progress_manager, 'training_id') or progress_manager.training_id != task_id:
                progress_manager.start_training(
                    training_id=task_id,
                    product_id=product_id,
                    model_type='mlstm',
                    training_mode=training_mode,
                    total_epochs=epochs,
                    # NOTE(review): total_batches/total_samples are placeholders
                    # (0) and batch_size is hard-coded; nothing in this function
                    # updates them afterwards.
                    total_batches=0,
                    batch_size=32,
                    total_samples=0
                )
                print(f"[mLSTM] 任务 {task_id}: 进度管理器已初始化", flush=True)
            else:
                print(f"[mLSTM] 任务 {task_id}: 使用现有进度管理器", flush=True)
        except Exception as e:
            # Best-effort: a progress-manager failure must not abort training.
            print(f"[mLSTM] 任务 {task_id}: 进度管理器初始化失败: {e}", flush=True)

    # --- Input validation ---
    # At least one epoch is required: the final payload reads train_losses[-1].
    if epochs < 1:
        error_msg = f"训练轮次无效: epochs 必须 >= 1, 但收到 {epochs}。"
        emit_progress(error_msg)
        raise ValueError(error_msg)

    min_required_samples = sequence_length + forecast_horizon
    if len(training_df) < min_required_samples:
        error_msg = f"训练数据不足: 需要 {min_required_samples} 条记录, 但只有 {len(training_df)} 条。"
        emit_progress(error_msg)
        raise ValueError(error_msg)

    # product_name may be missing (e.g. in globally aggregated data), so fall
    # back to the model identifier.
    product_name = training_df['product_name'].iloc[0] if 'product_name' in training_df.columns else model_identifier
    emit_progress(f"开始为 '{product_name}' (标识: {model_identifier}) 训练mLSTM模型")

    # --- Data pipeline ---
    emit_progress("数据预处理中...")
    # 1. Standardised preprocessing: scaled train/test sequences plus the fitted scalers.
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )
    # 2. Standardised DataLoader construction.
    batch_size = 32
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)

    total_batches = len(train_loader)
    total_samples = len(trainX)
    # Both splits must be non-empty: epoch losses divide by the batch count
    # and the final evaluation concatenates the test predictions.
    if total_batches == 0 or len(test_loader) == 0:
        error_msg = (f"训练数据不足: 训练批次 {total_batches}, 测试批次 {len(test_loader)}, "
                     f"无法完成训练与评估。")
        emit_progress(error_msg)
        raise ValueError(error_msg)

    print(f"[mLSTM] 数据加载器创建完成 - 批次数: {total_batches}, 样本数: {total_samples}", flush=True)
    emit_progress(f"数据加载器准备完成 - 批次数: {total_batches}, 样本数: {total_samples}")

    # --- Model: mLSTM combined with Transformer blocks ---
    input_dim = trainX.shape[2]  # trainX is indexed as 3-D; axis 2 is the feature count.
    output_dim = forecast_horizon
    hidden_size = 128
    num_heads = 4
    dropout_rate = 0.1
    num_blocks = 3
    embed_dim = 32
    dense_dim = 32
    mlstm_layers = 2

    print(f"[mLSTM] 初始化模型 - 输入维度: {input_dim}, 输出维度: {output_dim}", flush=True)
    print(f"[mLSTM] 模型参数 - 隐藏层: {hidden_size}, 注意力头: {num_heads}", flush=True)
    emit_progress(f"初始化mLSTM模型 - 输入维度: {input_dim}, 隐藏层: {hidden_size}")
    model = MatrixLSTM(
        num_features=input_dim,
        hidden_size=hidden_size,
        mlstm_layers=mlstm_layers,
        embed_dim=embed_dim,
        dense_dim=dense_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        num_blocks=num_blocks,
        output_sequence_length=output_dim
    )
    print(f"[mLSTM] 模型创建完成", flush=True)
    emit_progress("mLSTM模型初始化完成")

    # When continuing training, the existing model would be loaded here.
    if continue_training and version != 'v1':
        # TODO: Implement continue_training logic with the new model_manager.
        pass

    model = model.to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    emit_progress("数据预处理完成,开始模型训练...", progress=10)

    # Lock the version up front so that the best checkpoint, the loss curve
    # and the final model all share it.
    # NOTE(review): the ``version`` argument is only echoed above; it does not
    # affect the saved version — confirm this is intended.
    current_version = model_manager.peek_next_version(
        model_type='mlstm',
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    # Hyper-parameters persisted with every saved model so it can be rebuilt.
    model_config = {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'num_heads': num_heads,
        'dropout_rate': dropout_rate,
        'num_blocks': num_blocks,
        'embed_dim': embed_dim,
        'dense_dim': dense_dim,
        'mlstm_layers': mlstm_layers,
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': 'mlstm'
    }

    train_losses = []
    test_losses = []

    def build_payload(epoch_number, **extra_training_info):
        """Snapshot the model/optimizer state, loss history, scalers and metadata for model_manager."""
        return {
            'epoch': epoch_number,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'train_loss': train_losses[-1],
            'test_loss': test_losses[-1],
            'train_losses': list(train_losses),
            'test_losses': list(test_losses),
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
            'config': dict(model_config),
            'training_info': {
                'product_id': product_id,
                'product_name': product_name,
                'training_mode': training_mode,
                'store_id': store_id,
                'aggregation_method': aggregation_method,
                'training_scope': training_scope,
                'timestamp': time.time(),
                **extra_training_info,
            },
        }

    # --- Training loop ---
    start_time = time.time()
    best_loss = float('inf')
    epochs_no_improve = 0
    best_model_path = None
    epochs_run = 0

    emit_progress(f"开始训练 - 总epoch: {epochs}, 耐心值: {patience}")

    for epoch in range(epochs):
        emit_progress(f"开始训练 Epoch {epoch+1}/{epochs}")

        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer, clip_norm)
        train_losses.append(train_loss)
        test_loss = _average_loss(model, test_loader, criterion)
        test_losses.append(test_loss)
        epochs_run = epoch + 1

        scheduler.step(test_loss)

        # Overall progress: epochs map onto the 10-100% range.
        epoch_progress = (epochs_run / epochs) * 90 + 10
        current_metrics = {
            'train_loss': train_loss,
            'test_loss': test_loss,
            'epoch': epochs_run,
            'total_epochs': epochs,
            'learning_rate': optimizer.param_groups[0]['lr']
        }
        emit_progress(f"Epoch {epoch+1}/{epochs} 完成 - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
                      progress=epoch_progress, metrics=current_metrics)

        # Improvement is checked every epoch (not only at periodic checkpoints),
        # so the saved best model really is the lowest-test-loss epoch and
        # `patience` counts actual epochs.
        if test_loss < best_loss:
            best_loss = test_loss
            epochs_no_improve = 0
            best_model_path, _ = model_manager.save_model(
                model_data=build_payload(epochs_run),
                product_id=product_id,
                model_type='mlstm',
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
        else:
            epochs_no_improve += 1

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", flush=True)

        # Early stopping.
        if epochs_no_improve >= patience:
            emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
            break

    training_time = time.time() - start_time

    # --- Final evaluation (weights from the last epoch that ran) ---
    emit_progress("模型评估中...", progress=98)
    test_pred, test_true = _predict_batches(model, test_loader)
    # NOTE(review): inverse_transform needs the prediction width to match the
    # number of columns scaler_y was fitted on — confirm against prepare_data
    # when forecast_horizon > 1.
    test_pred_inv = scaler_y.inverse_transform(test_pred)
    test_true_inv = scaler_y.inverse_transform(test_true)

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}")

    emit_progress("保存最终模型...", progress=99)

    # The final model is the last epoch that actually ran (early stopping may
    # end before `epochs`).
    final_model_data = build_payload(epochs_run, training_completed=True)
    final_model_data['metrics'] = metrics

    # Scope and identifier used to build the standardised loss-curve file name.
    scope = training_mode
    if scope == 'product':
        identifier = model_identifier
    elif scope == 'store':
        identifier = store_id
    elif scope == 'global':
        identifier = aggregation_method
    else:
        identifier = product_name  # Fallback for unknown modes.

    loss_curve_path = plot_loss_curve(
        train_losses=train_losses,
        val_losses=test_losses,
        model_type='mlstm',
        scope=scope,
        identifier=identifier,
        version=current_version,  # The locked version.
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")
    final_model_data['loss_curve_path'] = loss_curve_path

    final_model_path, final_version = model_manager.save_model(
        model_data=final_model_data,
        product_id=product_id,
        model_type='mlstm',
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version
    )

    final_metrics = {
        'mse': metrics['mse'],
        'rmse': metrics['rmse'],
        'mae': metrics['mae'],
        'r2': metrics['r2'],
        'mape': metrics['mape'],
        'training_time': training_time,
        'final_epoch': epochs_run,
        'model_path': final_model_path,
        'version': final_version
    }
    emit_progress(f"✅ mLSTM模型训练完成版本 {final_version} 已保存", progress=100, metrics=final_metrics)

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        # Path returned by save_model for the best checkpoint (None if no epoch improved).
        "best_model": best_model_path,
        "version": final_version
    }
    return metrics, artifacts
2025-07-22 15:40:37 +08:00
# --- Register this trainer with the system ---
# Runs at import time: hands train_product_model_with_mlstm to
# register_trainer under the key 'mlstm' (presumably so the registry can look
# the trainer up by model type — confirm in models.model_registry).
from models.model_registry import register_trainer
register_trainer('mlstm', train_product_model_with_mlstm)