ShopTRAINING/server/trainers/mlstm_trainer.py
xz2000 e999ed4af2 ### 2025-07-15 (续): 训练器与核心调用层重构
**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**
*   **内容**: 完全重写了 `kan_trainer.py`。
    *   **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
    *   **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
    *   **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
    *   **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**
*   **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
    *   移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
    *   全面转向使用 `model_manager` 进行版本控制和文件保存。
    *   统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**
*   **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
    *   移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
    *   实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**
*   **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
    *   **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
    *   **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
    *   **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
    *   **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
2025-07-15 20:09:09 +08:00

304 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - mLSTM模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from datetime import datetime
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from utils.data_utils import create_dataset, PharmacyDataset
from analysis.metrics import evaluate_model
from core.config import (
DEVICE, LOOK_BACK, FORECAST_HORIZON
)
from utils.training_progress import progress_manager
from utils.model_manager import model_manager
def train_product_model_with_mlstm(
product_id,
product_df,
store_id=None,
training_mode='product',
aggregation_method='sum',
epochs=50,
socketio=None,
task_id=None,
continue_training=False,
progress_callback=None,
patience=10,
learning_rate=0.001,
clip_norm=1.0
):
"""
使用mLSTM训练产品销售预测模型
参数:
product_id: 产品ID
store_id: 店铺ID为None时使用全局数据
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'weighted')
epochs: 训练轮次
model_dir: 模型保存目录
version: 模型版本如果为None则自动生成
socketio: Socket.IO实例用于实时进度推送
task_id: 任务ID
continue_training: 是否继续训练
progress_callback: 进度回调函数,用于多进程训练
"""
# 创建WebSocket进度反馈函数支持多进程 """
def emit_progress(message, progress=None, metrics=None):
"""发送训练进度到前端"""
progress_data = {
'task_id': task_id,
'message': message,
'timestamp': time.time()
}
if progress is not None:
progress_data['progress'] = progress
if metrics is not None:
progress_data['metrics'] = metrics
if progress_callback:
try:
progress_callback(progress_data)
except Exception as e:
print(f"[mLSTM] 进度回调失败: {e}")
if socketio and task_id:
try:
socketio.emit('training_progress', progress_data, namespace='/training')
except Exception as e:
print(f"[mLSTM] WebSocket发送失败: {e}")
print(f"[mLSTM] {message}", flush=True)
import sys
sys.stdout.flush()
sys.stderr.flush()
emit_progress("开始mLSTM模型训练...")
# 1. 确定模型标识符和版本
model_type = 'mlstm'
if training_mode == 'store':
scope = f"{store_id}_{product_id}"
elif training_mode == 'global':
scope = f"{product_id}" if product_id else "all"
else:
scope = f"{product_id}_all"
model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
version = model_manager.get_next_version_number(model_identifier)
emit_progress(f"开始训练 mLSTM 模型 v{version}")
# 2. 获取模型版本路径
model_version_path = model_manager.get_model_version_path(model_type, training_mode, scope, version, aggregation_method)
emit_progress(f"模型将保存到: {model_version_path}")
if training_mode == 'store' and store_id:
training_scope = f"店铺 {store_id}"
elif training_mode == 'global':
training_scope = f"全局聚合({aggregation_method})"
else:
training_scope = "所有店铺"
min_required_samples = LOOK_BACK + FORECAST_HORIZON
if len(product_df) < min_required_samples:
error_msg = f"数据不足: 需要 {min_required_samples} 天, 实际 {len(product_df)} 天。"
print(error_msg)
emit_progress(f"训练失败:{error_msg}")
raise ValueError(error_msg)
product_name = product_df['product_name'].iloc[0]
print(f"[mLSTM] 使用mLSTM模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型", flush=True)
print(f"[mLSTM] 训练范围: {training_scope}", flush=True)
print(f"[mLSTM] 版本: v{version}", flush=True)
print(f"[mLSTM] 使用设备: {DEVICE}", flush=True)
print(f"[mLSTM] 数据量: {len(product_df)} 条记录", flush=True)
emit_progress(f"训练产品: {product_name} (ID: {product_id}) - {training_scope}")
# 创建特征和目标变量
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
print(f"[mLSTM] 开始数据预处理,特征: {features}", flush=True)
# 预处理数据
X = product_df[features].values
y = product_df[['sales']].values
print(f"[mLSTM] 特征矩阵形状: {X.shape}, 目标矩阵形状: {y.shape}", flush=True)
emit_progress("数据预处理中...")
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
print(f"[mLSTM] 数据归一化完成", flush=True)
train_size = int(len(X_scaled) * 0.8)
X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=32, shuffle=True)
test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=32, shuffle=False)
print(f"[mLSTM] 数据加载器创建完成 - 批次数: {total_batches}, 样本数: {total_samples}", flush=True)
emit_progress(f"数据加载器准备完成 - 批次数: {total_batches}, 样本数: {total_samples}")
input_dim = X_train.shape[1]
output_dim = FORECAST_HORIZON
hidden_size, num_heads, dropout_rate, num_blocks, embed_dim, dense_dim = 128, 4, 0.1, 3, 32, 32
model = MatrixLSTM(
num_features=input_dim, hidden_size=hidden_size, mlstm_layers=2, embed_dim=embed_dim,
dense_dim=dense_dim, num_heads=num_heads, dropout_rate=dropout_rate,
num_blocks=num_blocks, output_sequence_length=output_dim
).to(DEVICE)
print(f"[mLSTM] 模型创建完成", flush=True)
emit_progress("mLSTM模型初始化完成")
if continue_training:
emit_progress("继续训练模式启动,但当前重构版本将从头开始。")
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)
emit_progress("数据预处理完成,开始模型训练...", progress=10)
train_losses, test_losses = [], []
start_time = time.time()
checkpoint_interval = max(1, epochs // 10)
best_loss = float('inf')
epochs_no_improve = 0
emit_progress(f"开始训练 - 总epoch: {epochs}, 检查点间隔: {checkpoint_interval}, 耐心值: {patience}")
for epoch in range(epochs):
emit_progress(f"开始训练 Epoch {epoch+1}/{epochs}")
model.train()
epoch_loss = 0
for X_batch, y_batch in train_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
optimizer.zero_grad()
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
loss.backward()
if clip_norm:
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
optimizer.step()
epoch_loss += loss.item()
# 计算训练损失
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
# 在测试集上评估
model.eval()
test_loss = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
outputs = model(X_batch)
loss = criterion(outputs, y_batch)
test_loss += loss.item()
test_loss /= len(test_loader)
test_losses.append(test_loss)
# 更新学习率
scheduler.step(test_loss)
emit_progress(f"Epoch {epoch+1}/{epochs} 完成 - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}",
progress=10 + ((epoch + 1) / epochs) * 85)
# 定期保存检查点
# 3. 保存检查点 checkpoint_data = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
}
if (epoch + 1) % checkpoint_interval == 0:
model_manager.save_model_artifact(checkpoint_data, f"checkpoint_epoch_{epoch+1}.pth", model_version_path)
emit_progress(f"💾 保存训练检查点 epoch_{epoch+1}")
if test_loss < best_loss:
best_loss = test_loss
model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
epochs_no_improve = 0
else:
epochs_no_improve += 1
if epochs_no_improve >= patience:
emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
break
training_time = time.time() - start_time
loss_fig = plt.figure(figsize=(10, 6))
plt.plot(train_losses, label='Training Loss')
plt.plot(test_losses, label='Test Loss')
plt.title(f'mLSTM 损失曲线 - {product_name} (v{version}) - {training_scope}')
plt.xlabel('Epoch'); plt.ylabel('Loss'); plt.legend(); plt.grid(True)
model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
plt.close(loss_fig)
print(f"损失曲线已保存到: {os.path.join(model_version_path, 'loss_curve.png')}")
model.eval()
with torch.no_grad():
test_pred = model(torch.Tensor(testX).to(DEVICE)).cpu().numpy()
metrics = evaluate_model(scaler_y.inverse_transform(testY), scaler_y.inverse_transform(test_pred))
metrics['training_time'] = training_time
# 打印评估指标
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}")
print(f"RMSE: {metrics['rmse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")
print(f"R²: {metrics['r2']:.4f}")
print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {training_time:.2f}")
final_model_data = {
'epoch': epoch + 1,
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
}
model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)
metadata = {
'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
'timestamp': datetime.now().isoformat(), 'metrics': metrics,
'config': {
'input_dim': input_dim, 'output_dim': output_dim, 'hidden_size': hidden_size,
'num_heads': num_heads, 'dropout': dropout_rate, 'num_blocks': num_blocks,
'embed_dim': embed_dim, 'dense_dim': dense_dim,
'sequence_length': LOOK_BACK, 'forecast_horizon': FORECAST_HORIZON,
}
}
model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)
# 6. 更新版本文件
model_manager.update_version(model_identifier, version)
emit_progress(f"✅ mLSTM模型 v{version} 训练完成!", progress=100, metrics=metrics)
return model, metrics, version, model_version_path