ShopTRAINING/server/trainers/kan_trainer.py

325 lines
11 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - KAN模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import prepare_data, PharmacyDataset, prepare_sequences
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
def _align_to_target(outputs: torch.Tensor, targets: torch.Tensor):
    """Give 2-D model outputs / targets a trailing singleton axis.

    Each tensor is unsqueezed independently when it is 2-D, so that both reach
    ``nn.MSELoss`` with the same rank. Assumes the model emits either
    ``(batch, horizon)`` or ``(batch, horizon, 1)`` -- TODO confirm for both
    KAN variants.
    """
    if targets.dim() == 2:
        targets = targets.unsqueeze(-1)
    if outputs.dim() == 2:
        outputs = outputs.unsqueeze(-1)
    return outputs, targets


def _squeeze_trailing_singleton(array):
    """Drop a trailing size-1 axis from a 3-D array; return other arrays unchanged."""
    if array.ndim == 3 and array.shape[-1] == 1:
        return array.squeeze(-1)
    return array


def _train_one_epoch(model: nn.Module, loader, criterion, optimizer, desc: str) -> float:
    """Run one optimisation pass over ``loader`` and return the mean batch loss.

    When the model exposes ``regularization_loss()``, ``0.01 *`` that value is
    added to every batch loss. The reported training loss therefore is not
    directly comparable to the evaluation loss from ``_evaluate_loss``.
    """
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in tqdm(loader, desc=desc, leave=False):
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
        outputs, y_batch = _align_to_target(model(X_batch), y_batch)
        loss = criterion(outputs, y_batch)
        if hasattr(model, 'regularization_loss'):
            loss = loss + model.regularization_loss() * 0.01
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate_loss(model: nn.Module, loader, criterion) -> float:
    """Return the mean un-regularised batch loss of ``model`` over ``loader``."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            outputs, y_batch = _align_to_target(model(X_batch), y_batch)
            total_loss += criterion(outputs, y_batch).item()
    return total_loss / len(loader)


def _predict_all(model: nn.Module, loader):
    """Predict every sample in ``loader``; return ``(predictions, targets)`` as numpy arrays.

    The forward pass runs batch by batch, so the whole test set is never
    resident on the device at once. A trailing singleton axis is dropped from
    both arrays so they can go straight to ``scaler.inverse_transform``.
    """
    model.eval()
    predictions, targets = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            predictions.append(model(X_batch.to(DEVICE)).cpu().numpy())
            targets.append(y_batch.cpu().numpy())
    pred = _squeeze_trailing_singleton(np.concatenate(predictions, axis=0))
    true = _squeeze_trailing_singleton(np.concatenate(targets, axis=0))
    return pred, true


def _build_model_config(input_dim, output_dim, hidden_size, sequence_length,
                        forecast_horizon, model_type, use_optimized) -> dict:
    """Build the architecture/config dict stored alongside a checkpoint.

    A fresh dict (and a fresh ``hidden_sizes`` list) is returned on every call,
    so the best-checkpoint and final payloads never share mutable state.
    """
    return {
        'input_dim': input_dim,
        'output_dim': output_dim,
        'hidden_size': hidden_size,
        'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
        'sequence_length': sequence_length,
        'forecast_horizon': forecast_horizon,
        'model_type': model_type,
        'use_optimized': use_optimized,
    }


def _resolve_scope_identifier(training_mode, model_identifier, store_id,
                              aggregation_method, product_name):
    """Pick the identifier used in the loss-curve file name for ``training_mode``.

    'product' -> model_identifier, 'store' -> store_id,
    'global' -> aggregation_method, anything else -> product_name (fallback).
    """
    if training_mode == 'product':
        return model_identifier
    if training_mode == 'store':
        return store_id
    if training_mode == 'global':
        return aggregation_method
    return product_name


def train_product_model_with_kan(
    model_identifier: str,
    training_df: pd.DataFrame,
    feature_list: list,
    training_mode: str,
    epochs: int = 50,
    sequence_length: int = LOOK_BACK,
    forecast_horizon: int = FORECAST_HORIZON,
    use_optimized: bool = False,
    model_dir: str = DEFAULT_MODEL_DIR,
    product_id: str = None,
    store_id: str = None,
    aggregation_method: str = None,
    version: str = None,
    **kwargs
):
    """Train a KAN (or optimised KAN) sales-forecasting model (new data pipeline).

    Args:
        model_identifier: Label used in log output. It is the fallback for
            ``product_name`` and the loss-curve identifier in 'product' mode.
        training_df: Input rows. Must have at least
            ``sequence_length + forecast_horizon`` rows. An optional
            'product_name' column supplies the display name.
        feature_list: Pre-selected feature columns passed to ``prepare_data``.
        training_mode: Scope of the model ('product', 'store', 'global', ...).
            It is forwarded to the model manager and used to choose the
            loss-curve identifier.
        epochs: Number of training epochs.
        sequence_length: Look-back window length.
        forecast_horizon: Number of future steps predicted per sample.
        use_optimized: Use ``OptimizedKANForecaster`` / model type
            'optimized_kan' instead of ``KANForecaster`` / 'kan'.
        model_dir: Directory passed to ``plot_loss_curve``.
        product_id, store_id, aggregation_method: Forwarded to the model
            manager for versioning and saving.
        version: NOTE(review): currently ignored. The version is always taken
            from ``model_manager.peek_next_version`` -- confirm whether callers
            expect this to be honoured.
        **kwargs: Accepted and ignored. Presumably this lets the registry call
            every trainer with a uniform argument set -- verify.

    Returns:
        ``(metrics, artifacts)``. ``metrics`` is the ``evaluate_model`` result
        plus 'training_time' (seconds), computed on the final-epoch weights
        (not the best checkpoint). ``artifacts`` has the keys
        'versioned_model', 'loss_curve_plot', 'best_model' and 'version'.

    Raises:
        ValueError: If ``training_df`` is too short, or if the train/test split
            yields an empty loader.
    """
    min_required_samples = sequence_length + forecast_horizon
    if len(training_df) < min_required_samples:
        raise ValueError(f"数据不足: 需要 {min_required_samples} 条, 实际 {len(training_df)} 条。")

    product_name = (
        training_df['product_name'].iloc[0]
        if 'product_name' in training_df.columns
        else model_identifier
    )
    # Human-readable label for logs vs. machine key used for storage/config.
    display_name = "优化版KAN" if use_optimized else "KAN"
    model_type = 'optimized_kan' if use_optimized else 'kan'
    print(f"开始为 '{product_name}' (标识: {model_identifier}) 训练{display_name}模型")

    # --- Data pipeline: standardised preprocessing + DataLoaders ---
    print(f"[{display_name}] 开始数据预处理,使用 {len(feature_list)} 个预选特征...")
    _, _, trainX, testX, trainY, testY, scaler_X, scaler_y = prepare_data(
        training_df=training_df,
        feature_list=feature_list,
        target_column='net_sales_quantity',
        sequence_length=sequence_length,
        forecast_horizon=forecast_horizon
    )
    batch_size = 32
    train_loader = prepare_sequences(trainX, trainY, batch_size)
    test_loader = prepare_sequences(testX, testY, batch_size)
    # Fail fast with a clear error instead of a ZeroDivisionError in the
    # loss averaging, or an empty concatenate during evaluation.
    if len(train_loader) == 0 or len(test_loader) == 0:
        raise ValueError(
            f"数据划分后训练集或测试集为空: 训练批次 {len(train_loader)}, 测试批次 {len(test_loader)}。"
        )

    # --- Model ---
    input_dim = trainX.shape[2]  # features per time step (trainX indexed as 3-D)
    output_dim = forecast_horizon
    hidden_size = 64
    model_cls = OptimizedKANForecaster if use_optimized else KANForecaster
    model = model_cls(
        input_features=input_dim,
        hidden_sizes=[hidden_size, hidden_size * 2, hidden_size],
        output_sequence_length=output_dim
    ).to(DEVICE)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    def make_config():
        return _build_model_config(
            input_dim=input_dim,
            output_dim=output_dim,
            hidden_size=hidden_size,
            sequence_length=sequence_length,
            forecast_horizon=forecast_horizon,
            model_type=model_type,
            use_optimized=use_optimized,
        )

    # Local import kept from the original layout (presumably avoids an import
    # cycle); done once instead of on every save.
    from utils.model_manager import model_manager
    # Lock the version up front so the best checkpoint, loss curve and final
    # model all share it.
    current_version = model_manager.peek_next_version(
        model_type=model_type,
        product_id=product_id,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 本次训练版本锁定为: {current_version}")

    # --- Training loop ---
    train_losses = []
    test_losses = []
    start_time = time.time()
    best_loss = float('inf')
    best_model_path = None
    for epoch in range(epochs):
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer,
            desc=f"Epoch {epoch+1}/{epochs}"
        )
        train_losses.append(train_loss)

        test_loss = _evaluate_loss(model, test_loader, criterion)
        test_losses.append(test_loss)

        if test_loss < best_loss:
            best_loss = test_loss
            print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
            # state_dict() references the live parameters. That is fine here
            # because the payload is handed to save_model immediately.
            # NOTE(review): confirm save_model serialises synchronously.
            best_model_data = {
                'model_state_dict': model.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': make_config(),
                'epoch': epoch + 1
            }
            best_model_path, _ = model_manager.save_model(
                model_data=best_model_data,
                product_id=product_id,
                model_type=model_type,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    training_time = time.time() - start_time

    # --- Evaluation on the final-epoch weights, in original units ---
    test_pred, test_true = _predict_all(model, test_loader)
    test_pred_inv = scaler_y.inverse_transform(test_pred)
    test_true_inv = scaler_y.inverse_transform(test_true)

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}")

    # --- Artifacts: versioned loss curve + final model ---
    scope = training_mode
    identifier = _resolve_scope_identifier(
        training_mode, model_identifier, store_id, aggregation_method, product_name
    )
    loss_curve_path = plot_loss_curve(
        train_losses=train_losses,
        val_losses=test_losses,
        model_type=model_type,
        scope=scope,
        identifier=identifier,
        version=current_version,  # the locked version
        model_dir=model_dir
    )
    print(f"📈 带版本号的损失曲线已保存: {loss_curve_path}")

    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': make_config(),
        'metrics': metrics,
        'loss_history': {
            'train': train_losses,
            'test': test_losses,
            'epochs': list(range(1, epochs + 1))
        },
        'loss_curve_path': loss_curve_path
    }
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=product_id,
        model_type=model_type,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version  # the locked version
    )
    print(f"{model_type} 最终模型已保存,版本: {final_version}")

    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version
    }
    return metrics, artifacts
# --- Register this trainer with the system ---
# The same entry point is registered under both KAN keys. The variant is
# chosen inside the function by its ``use_optimized`` argument.
# NOTE(review): confirm the registry passes use_optimized=True for 'optimized_kan'.
from models.model_registry import register_trainer

for _kan_trainer_key in ('kan', 'optimized_kan'):
    register_trainer(_kan_trainer_key, train_product_model_with_kan)
del _kan_trainer_key