ShopTRAINING/server/trainers/transformer_trainer.py
xz2000 e999ed4af2 ### 2025-07-15 (续): 训练器与核心调用层重构
**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**
*   **内容**: 完全重写了 `kan_trainer.py`。
    *   **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
    *   **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
    *   **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
    *   **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**
*   **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
    *   移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
    *   全面转向使用 `model_manager` 进行版本控制和文件保存。
    *   统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**
*   **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
    *   移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
    *   实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**
*   **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
    *   **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
    *   **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
    *   **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
    *   **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
2025-07-15 20:09:09 +08:00

277 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - Transformer模型训练函数 (已重构)
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from datetime import datetime
from models.transformer_model import TimeSeriesTransformer
from utils.data_utils import create_dataset, PharmacyDataset
from analysis.metrics import evaluate_model
from core.config import DEVICE, LOOK_BACK, FORECAST_HORIZON
from utils.model_manager import model_manager
def train_product_model_with_transformer(
    product_id,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    socketio=None,
    task_id=None,
    progress_callback=None,
    patience=10,
    learning_rate=0.001,
    clip_norm=1.0
):
    """Train a Transformer sales-forecasting model and persist it via ``model_manager``.

    The function resolves a model identifier/version, loads (or receives) the
    sales data, scales it, builds sliding windows, trains with early stopping
    and finally stores the loss curve, best checkpoint, final weights and
    metadata under the version path returned by ``model_manager``.

    Args:
        product_id: Product to train for. In ``'global'`` mode a falsy value
            makes the scope ``"all"``.
        product_df: Optional pre-loaded DataFrame. When ``None`` the data is
            loaded from ``pharmacy_sales_multi_store.csv`` (with
            ``pharmacy_sales.xlsx`` as fallback). It must contain the columns
            ``date``, ``product_name``, ``sales``, ``weekday``, ``month``,
            ``is_holiday``, ``is_weekend``, ``is_promotion`` and
            ``temperature``.
        store_id: Store identifier, used when ``training_mode == 'store'``.
        training_mode: ``'product'`` (default), ``'store'`` or ``'global'``.
        aggregation_method: Aggregation passed to the global data loader and
            to ``model_manager``.
        epochs: Maximum number of training epochs.
        socketio: Optional Socket.IO server; progress is emitted on the
            ``training_progress`` event of the ``/training`` namespace when
            both ``socketio`` and ``task_id`` are set.
        task_id: Identifier attached to every progress payload.
        progress_callback: Optional callable receiving each progress payload.
        patience: Epochs without test-loss improvement before stopping early;
            half of it (integer division) is used as the LR-scheduler patience.
        learning_rate: Initial Adam learning rate.
        clip_norm: Max gradient norm; a falsy value disables clipping.

    Returns:
        tuple: ``(model, metrics, version, model_version_path)``.

    Raises:
        ValueError: If the data has fewer than ``LOOK_BACK + FORECAST_HORIZON``
            rows, or if the 80/20 split leaves the train or test part without
            a single sample window.
    """

    def emit_progress(message, progress=None, metrics=None):
        """Send a progress update to the callback, the websocket and stdout."""
        progress_data = {
            'task_id': task_id,
            'message': f"[Transformer] {message}",
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics

        # Both channels are best-effort: a failing listener must not abort training.
        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                print(f"[Transformer] 进度回调失败: {e}")
        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[Transformer] WebSocket发送失败: {e}")
        print(f"[Transformer] {message}", flush=True)

    emit_progress("开始Transformer模型训练...")

    # 1. Resolve the model identifier and the next version number.
    model_type = 'transformer'
    if training_mode == 'store':
        scope = f"{store_id}_{product_id}"
    elif training_mode == 'global':
        scope = f"{product_id}" if product_id else "all"
    else:  # 'product' mode
        scope = f"{product_id}_all"

    model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
    version = model_manager.get_next_version_number(model_identifier)
    emit_progress(f"开始训练 {model_type} 模型 v{version}")

    # 2. Resolve the directory all artifacts of this version are written to.
    model_version_path = model_manager.get_model_version_path(model_type, training_mode, scope, version, aggregation_method)
    emit_progress(f"模型将保存到: {model_version_path}")

    # 3. Data loading and preprocessing.
    if product_df is None:
        from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
        try:
            if training_mode == 'store' and store_id:
                product_df = get_store_product_sales_data(store_id, product_id, 'pharmacy_sales_multi_store.csv')
            elif training_mode == 'global':
                product_df = aggregate_multi_store_data(product_id, aggregation_method=aggregation_method, file_path='pharmacy_sales_multi_store.csv')
            else:
                product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
        except Exception as e:
            # Fall back to the legacy single-file data set.
            emit_progress(f"多店铺数据加载失败: {e}, 尝试后备方案...")
            df = pd.read_excel('pharmacy_sales.xlsx')
            product_df = df[df['product_id'] == product_id].sort_values('date')

    if training_mode == 'store' and store_id:
        training_scope = f"店铺 {store_id}"
    elif training_mode == 'global':
        training_scope = f"全局聚合({aggregation_method})"
    else:
        training_scope = "所有店铺"

    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    if len(product_df) < min_required_samples:
        error_msg = f"数据不足: 需要 {min_required_samples} 天, 实际 {len(product_df)} 天。"
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]
    emit_progress(f"训练产品: '{product_name}' (ID: {product_id}) - {training_scope}")
    emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)}")

    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values

    # NOTE(review): both scalers are fitted on the full series before the
    # train/test split, so test-period statistics influence the scaling.
    # Kept unchanged because the fitted scalers are stored with the model.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Chronological 80/20 split (no shuffling before the split).
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)

    # The row-count check above only covers the full series. After the split
    # either part may be too short to yield a single window, which would
    # otherwise surface later as a ZeroDivisionError when averaging the loss
    # over an empty loader. len() is used instead of truthiness because the
    # windows are array-like.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"数据不足: 按 80/20 划分后无法生成样本窗口 "
            f"(训练窗口 {len(trainX)} 个, 测试窗口 {len(testX)} 个)。"
        )
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    batch_size = 32
    train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=batch_size, shuffle=False)

    # 4. Model initialisation.
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    hidden_size = 64
    num_heads = 4
    dropout_rate = 0.1
    num_layers = 3

    model = TimeSeriesTransformer(
        num_features=input_dim,
        d_model=hidden_size,
        nhead=num_heads,
        num_encoder_layers=num_layers,
        dim_feedforward=hidden_size * 2,
        dropout=dropout_rate,
        output_sequence_length=output_dim,
        seq_length=LOOK_BACK
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=patience // 2, factor=0.5)

    emit_progress("数据预处理完成,开始模型训练...", progress=10)

    # 5. Training loop with early stopping on the test loss.
    train_losses, test_losses = [], []
    start_time = time.time()
    best_loss = float('inf')
    epochs_no_improve = 0

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            optimizer.zero_grad()
            loss.backward()
            if clip_norm:
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_norm)
            optimizer.step()
            epoch_loss += loss.item()

        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                test_loss += loss.item()
        test_loss /= len(test_loader)
        test_losses.append(test_loss)

        scheduler.step(test_loss)

        # Progress 10..95 is reserved for the epochs; 100 is sent on completion.
        progress_percentage = 10 + ((epoch + 1) / epochs) * 85
        emit_progress(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", progress=progress_percentage)

        if test_loss < best_loss:
            best_loss = test_loss
            epochs_no_improve = 0
            checkpoint_data = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
            }
            model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            emit_progress(f"连续 {patience} 个epoch测试损失未改善提前停止训练。")
            break

    training_time = time.time() - start_time

    # 6. Save the loss curve and evaluate.
    # Draw through the figure's own Axes so the plot does not depend on
    # pyplot's global "current figure", and always release the figure.
    loss_fig = plt.figure(figsize=(10, 6))
    try:
        loss_ax = loss_fig.add_subplot(1, 1, 1)
        loss_ax.plot(train_losses, label='Training Loss')
        loss_ax.plot(test_losses, label='Test Loss')
        loss_ax.set_title(f'{model_type.upper()} 损失曲线 - {product_name} (v{version}) - {training_scope}')
        loss_ax.set_xlabel('Epoch')
        loss_ax.set_ylabel('Loss')
        loss_ax.legend()
        loss_ax.grid(True)
        model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
    finally:
        plt.close(loss_fig)
    emit_progress(f"损失曲线已保存到: {os.path.join(model_version_path, 'loss_curve.png')}")

    # Evaluation uses the weights of the last executed epoch; the best-epoch
    # weights are available separately in checkpoint_best.pth.
    model.eval()
    with torch.no_grad():
        testX_tensor = torch.Tensor(testX).to(DEVICE)
        test_pred = model(testX_tensor).cpu().numpy()

    # Map predictions and targets back to the original sales scale.
    test_pred_inv = scaler_y.inverse_transform(test_pred)
    test_true_inv = scaler_y.inverse_transform(testY)

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time
    emit_progress(f"模型评估完成 - RMSE: {metrics['rmse']:.4f}, R²: {metrics['r2']:.4f}")

    # 7. Save the final model and its metadata.
    final_model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
    }
    model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)

    metadata = {
        'product_id': product_id, 'product_name': product_name, 'model_type': model_type,
        'version': f'v{version}', 'training_mode': training_mode, 'scope': scope,
        'aggregation_method': aggregation_method, 'training_scope_description': training_scope,
        'timestamp': datetime.now().isoformat(), 'metrics': metrics,
        'config': {
            'input_dim': input_dim, 'output_dim': output_dim, 'd_model': hidden_size,
            'nhead': num_heads, 'num_encoder_layers': num_layers, 'dim_feedforward': hidden_size * 2,
            'dropout': dropout_rate, 'sequence_length': LOOK_BACK, 'forecast_horizon': FORECAST_HORIZON,
        }
    }
    model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)

    # 8. Record this version as the latest one for the identifier.
    model_manager.update_version(model_identifier, version)

    emit_progress(f"{model_type.upper()}模型 v{version} 训练完成!", progress=100, metrics=metrics)
    return model, metrics, version, model_version_path