ShopTRAINING/server/trainers/kan_trainer.py

352 lines
13 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - KAN模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
2025-07-16 18:50:16 +08:00
def _describe_training_scope(training_mode, store_id, aggregation_method):
    """Return the display label for the data scope of a training run.

    Uses the same branch order as the data loader: a store-scoped run needs
    both ``training_mode == 'store'`` and a truthy ``store_id``; ``'global'``
    means data aggregated across stores; anything else means all stores.
    """
    if training_mode == 'store' and store_id:
        return f"店铺 {store_id}"
    if training_mode == 'global':
        return f"全局聚合({aggregation_method})"
    return "所有店铺"


def _load_product_sales(product_id, store_id, training_mode, aggregation_method):
    """Load the sales rows of ``product_id`` for the requested training scope.

    Returns:
        (product_df, training_scope): the loaded DataFrame and its display label.

    If any multi-store loader raises, the error is printed and the rows are
    read from the legacy ``pharmacy_sales.xlsx`` file instead (label "原始数据").
    """
    # Imported lazily so callers that pass product_df directly never need it.
    from utils.multi_store_data_utils import (
        load_multi_store_data,
        get_store_product_sales_data,
        aggregate_multi_store_data,
    )

    data_file = 'pharmacy_sales_multi_store.csv'
    try:
        if training_mode == 'store' and store_id:
            product_df = get_store_product_sales_data(store_id, product_id, data_file)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(
                product_id,
                aggregation_method=aggregation_method,
                file_path=data_file,
            )
        else:
            product_df = load_multi_store_data(data_file, product_id=product_id)
    except Exception as e:
        # Deliberate best-effort fallback: keep training possible on the legacy dataset.
        print(f"多店铺数据加载失败: {e}")
        df = pd.read_excel('pharmacy_sales.xlsx')
        product_df = df[df['product_id'] == product_id].sort_values('date')
        return product_df, "原始数据"
    return product_df, _describe_training_scope(training_mode, store_id, aggregation_method)


def _align_output_and_target(outputs, targets):
    """Give 2-D model outputs / targets a trailing feature axis.

    Tensors arriving as (batch, forecast_horizon) become
    (batch, forecast_horizon, 1) so the loss compares like-shaped tensors.
    """
    if targets.dim() == 2:
        targets = targets.unsqueeze(-1)
    if outputs.dim() == 2:
        outputs = outputs.unsqueeze(-1)
    return outputs, targets


def _train_one_epoch(model, loader, criterion, optimizer, desc):
    """Run one optimisation pass over ``loader``; return the mean batch loss.

    ``loader`` must be non-empty (the caller validates this).
    """
    model.train()
    running_loss = 0.0
    for X_batch, y_batch in tqdm(loader, desc=desc, leave=False):
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
        outputs, y_batch = _align_output_and_target(model(X_batch), y_batch)
        loss = criterion(outputs, y_batch)
        # If the model exposes a regularization_loss() (as KAN models may),
        # add it to the objective with a fixed weight of 0.01.
        if hasattr(model, 'regularization_loss'):
            loss = loss + model.regularization_loss() * 0.01
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    return running_loss / len(loader)


def _mean_eval_loss(model, loader, criterion):
    """Return the mean batch loss of ``model`` over ``loader`` without gradients.

    ``loader`` must be non-empty (the caller validates this).
    """
    model.eval()
    total = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            outputs, y_batch = _align_output_and_target(model(X_batch), y_batch)
            total += criterion(outputs, y_batch).item()
    return total / len(loader)


def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
    """Train a KAN (or optimized KAN) sales-forecasting model for one product.

    Args:
        product_id: ID of the product whose sales are modelled.
        model_identifier: Unique identifier passed to ``model_manager.save_model``
            as its ``product_id`` (instead of the raw product ID).
        product_df: Pre-loaded sales DataFrame. When None, data is loaded
            according to ``training_mode``.
        store_id: Store to train on when ``training_mode == 'store'``.
        training_mode: 'store', 'global', or anything else (default
            'product') for data from all stores.
        aggregation_method: Aggregation used in 'global' mode; also forwarded
            to the model manager.
        epochs: Number of training epochs.
        sequence_length: Number of input time steps per sample (look-back).
        forecast_horizon: Number of future steps the model predicts.
        use_optimized: Use ``OptimizedKANForecaster`` instead of ``KANForecaster``.
        model_dir: Directory that is printed and passed to ``plot_loss_curve``.
            NOTE(review): the ``model_manager.save_model`` calls below do not
            receive it — confirm checkpoints are stored where expected.

    Returns:
        (model, metrics): the model as of the final epoch (the best epoch is
        saved separately as version 'best') and its test-split metrics, plus
        ``metrics['training_time']`` in seconds.

    Raises:
        ValueError: if there are no sales rows, or too few rows to build both
            training and test windows.
    """
    if product_df is None:
        product_df, training_scope = _load_product_sales(
            product_id, store_id, training_mode, aggregation_method
        )
    else:
        training_scope = _describe_training_scope(training_mode, store_id, aggregation_method)

    if product_df.empty:
        raise ValueError(f"产品 {product_id} 没有可用的销售数据")

    # Cheap early check on the whole series; the per-split check follows windowing.
    min_required_samples = sequence_length + forecast_horizon
    if len(product_df) < min_required_samples:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
            f"实际数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}\n"
            f"建议解决方案:\n"
            f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
            f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
            f"3. 使用全局训练模式聚合更多数据"
        )
        print(error_msg)
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]
    model_type = "优化版KAN" if use_optimized else "KAN"   # display name
    model_type_name = 'optimized_kan' if use_optimized else 'kan'  # storage key
    print(f"使用{model_type}模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
    print(f"训练范围: {training_scope}")
    print(f"使用设备: {DEVICE}")
    print(f"模型将保存到目录: {model_dir}")

    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
    X = product_df[features].values
    y = product_df[['sales']].values  # kept 2-D for MinMaxScaler

    # NOTE(review): both scalers are fit on the full series, so the min/max of the
    # test portion leak into training; consider fitting on the training split only.
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Chronological 80/20 split (no shuffling, preserving time order).
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # The whole-series check above ignores the split: each split needs its own
    # windows, otherwise its DataLoader is empty and the mean-loss division fails.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"❌ 训练数据不足错误\n"
            f"按 80/20 划分后无法构建足够的时间窗口: 训练样本 {len(trainX)} 个, 测试样本 {len(testX)} 个\n"
            f"每个划分都需要约 {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon}), "
            f"实际总数据量: {len(product_df)}\n"
            f"产品ID: {product_id}, 训练模式: {training_mode}"
        )
        print(error_msg)
        raise ValueError(error_msg)

    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    batch_size = 32
    train_loader = DataLoader(PharmacyDataset(trainX_tensor, trainY_tensor), batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(PharmacyDataset(testX_tensor, testY_tensor), batch_size=batch_size, shuffle=False)

    input_dim = X_train.shape[1]
    output_dim = forecast_horizon
    hidden_size = 64

    def make_config():
        # A fresh dict per checkpoint, so saved configs never share mutable state.
        return {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': model_type_name,
            'use_optimized': use_optimized,
        }

    model_cls = OptimizedKANForecaster if use_optimized else KANForecaster
    model = model_cls(
        input_features=input_dim,
        hidden_sizes=[hidden_size, hidden_size * 2, hidden_size],
        output_sequence_length=output_dim,
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Imported once here instead of on every improvement inside the epoch loop.
    from utils.model_manager import model_manager

    train_losses = []
    test_losses = []
    best_loss = float('inf')
    start_time = time.time()
    for epoch in range(epochs):
        train_loss = _train_one_epoch(
            model, train_loader, criterion, optimizer, desc=f"Epoch {epoch+1}/{epochs}"
        )
        train_losses.append(train_loss)
        test_loss = _mean_eval_loss(model, test_loader, criterion)
        test_losses.append(test_loss)

        if test_loss < best_loss:
            best_loss = test_loss
            print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
            # Persisted right away: state_dict() tensors share storage with the
            # live parameters, which keep changing in later epochs.
            model_manager.save_model(
                model_data={
                    'model_state_dict': model.state_dict(),
                    'scaler_X': scaler_X,
                    'scaler_y': scaler_y,
                    'config': make_config(),
                    'epoch': epoch + 1,
                },
                product_id=model_identifier,
                model_type=model_type_name,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version='best',  # explicitly overwrite the 'best' version
            )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    training_time = time.time() - start_time

    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        model_type,
        model_dir=model_dir,
    )
    print(f"损失曲线已保存到: {loss_curve_path}")

    # Evaluate the final-epoch model on the whole test split at once.
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
    # reshape(-1, 1) flattens (N, H) and (N, H, 1) identically, so no squeeze is needed.
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    print("\n模型评估指标:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"训练时间: {training_time:.2f}")

    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': make_config(),
        'metrics': metrics,
        'loss_history': {
            'train': train_losses,
            'test': test_losses,
            'epochs': list(range(1, epochs + 1)),
        },
        'loss_curve_path': loss_curve_path,
    }
    # No version argument: the model manager assigns the version number itself.
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=model_identifier,
        model_type=model_type_name,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
    )
    print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
    return model, metrics