ShopTRAINING/server/trainers/kan_trainer.py

364 lines
14 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - KAN模型训练函数
"""
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import create_dataset, PharmacyDataset
2025-07-25 18:42:58 +08:00
from utils.new_data_loader import load_new_data
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON
2025-07-16 18:50:16 +08:00
def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
"""
使用KAN模型训练产品销售预测模型
参数:
product_id: 产品ID
epochs: 训练轮次
use_optimized: 是否使用优化版KAN
model_dir: 模型保存目录默认使用配置中的DEFAULT_MODEL_DIR
返回:
model: 训练好的模型
metrics: 模型评估指标
"""
2025-07-25 18:42:58 +08:00
# --- 数据加载与筛选重构 ---
# 统一使用新的数据加载器,替换掉所有旧的、分散的加载逻辑
print("正在使用新的统一数据加载器...")
full_df = load_new_data() # 加载完整的、适配后的新数据
if training_mode == 'store' and store_id:
store_df = full_df[full_df['store_id'] == store_id].copy()
if product_id and product_id != 'unknown' and product_id != 'all_products':
product_df = store_df[store_df['product_id'] == product_id].copy()
training_scope = f"店铺 {store_id} - 产品 {product_id}"
2025-07-02 11:05:23 +08:00
else:
2025-07-25 18:42:58 +08:00
product_df = store_df.groupby('date').agg({
'sales': 'sum', 'weekday': 'first', 'month': 'first',
'is_holiday': 'first', 'is_weekend': 'first',
'is_promotion': 'first', 'temperature': 'mean'
}).reset_index()
training_scope = f"店铺 {store_id} (所有药品聚合)"
2025-07-26 14:41:41 +08:00
# 数据清洗使用0填充聚合后可能产生的NaN值
product_df.fillna(0, inplace=True)
2025-07-25 18:42:58 +08:00
elif training_mode == 'global':
# 筛选特定产品在所有店铺的聚合数据
# 注意:新数据已经是按 (store_id, product_id, date) 展开的,聚合逻辑可能需要重新审视
# 此处暂时只筛选产品ID
product_df = full_df[full_df['product_id'] == product_id].copy()
# 按日期对同一产品在不同店铺的销售额求和
product_df = product_df.groupby('date').agg({
'sales': 'sum',
# 保留其他需要的特征,例如取第一个非空值或平均值
'weekday': 'first',
'month': 'first',
'is_holiday': 'first',
'is_weekend': 'first',
'is_promotion': 'first',
'temperature': 'mean'
}).reset_index()
training_scope = f"全局聚合({aggregation_method})"
else: # 默认 'product' 模式
# 筛选特定产品的数据(可能跨越多个店铺,但此处不聚合)
product_df = full_df[full_df['product_id'] == product_id].copy()
training_scope = f"所有店铺中的产品 {product_id}"
2025-07-02 11:05:23 +08:00
if product_df.empty:
raise ValueError(f"产品 {product_id} 没有可用的销售数据")
# 数据量检查
2025-07-16 12:59:56 +08:00
min_required_samples = sequence_length + forecast_horizon
2025-07-02 11:05:23 +08:00
if len(product_df) < min_required_samples:
error_msg = (
f"❌ 训练数据不足错误\n"
2025-07-16 12:59:56 +08:00
f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n"
2025-07-02 11:05:23 +08:00
f"实际数据量: {len(product_df)}\n"
f"产品ID: {product_id}, 训练模式: {training_mode}\n"
f"建议解决方案:\n"
f"1. 生成更多数据: uv run generate_multi_store_data.py\n"
f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n"
f"3. 使用全局训练模式聚合更多数据"
)
print(error_msg)
raise ValueError(error_msg)
product_df = product_df.sort_values('date')
2025-07-25 18:42:58 +08:00
# 兼容性处理:新数据可能没有 product_name 列
if 'product_name' in product_df.columns:
product_name = product_df['product_name'].iloc[0]
else:
product_name = f"Product {product_id}" # 使用 product_id 作为备用名称
model_type = "优化版KAN" if use_optimized else "KAN"
print(f"使用{model_type}模型训练产品 '{product_name}' (ID: {product_id}) 的销售预测模型")
2025-07-02 11:05:23 +08:00
print(f"训练范围: {training_scope}")
print(f"使用设备: {DEVICE}")
print(f"模型将保存到目录: {model_dir}")
# 创建特征和目标变量
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 预处理数据
X = product_df[features].values
y = product_df[['sales']].values # 保持为二维数组
# 归一化数据
scaler_X = MinMaxScaler(feature_range=(0, 1))
scaler_y = MinMaxScaler(feature_range=(0, 1))
X_scaled = scaler_X.fit_transform(X)
y_scaled = scaler_y.fit_transform(y)
# 划分训练集和测试集80% 训练20% 测试)
train_size = int(len(X_scaled) * 0.8)
X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]
# 创建时间序列数据
2025-07-16 12:59:56 +08:00
trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)
# 转换为PyTorch的Tensor
trainX_tensor = torch.Tensor(trainX)
trainY_tensor = torch.Tensor(trainY)
testX_tensor = torch.Tensor(testX)
testY_tensor = torch.Tensor(testY)
# 创建数据加载器
train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
test_dataset = PharmacyDataset(testX_tensor, testY_tensor)
batch_size = 32
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
# 初始化KAN模型
input_dim = X_train.shape[1]
2025-07-16 12:59:56 +08:00
output_dim = forecast_horizon
hidden_size = 64
if use_optimized:
model = OptimizedKANForecaster(
input_features=input_dim,
hidden_sizes=[hidden_size, hidden_size*2, hidden_size],
output_sequence_length=output_dim
)
else:
model = KANForecaster(
input_features=input_dim,
hidden_sizes=[hidden_size, hidden_size*2, hidden_size],
output_sequence_length=output_dim
)
# 将模型移动到设备上
model = model.to(DEVICE)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# 训练模型
train_losses = []
test_losses = []
start_time = time.time()
2025-07-17 17:54:53 +08:00
best_loss = float('inf')
for epoch in range(epochs):
model.train()
epoch_loss = 0
for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状 (batch_size, forecast_horizon, 1)
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
# 前向传播
outputs = model(X_batch)
# 确保输出形状与目标匹配
if outputs.dim() == 2:
outputs = outputs.unsqueeze(-1)
loss = criterion(outputs, y_batch)
# 如果是KAN模型加入正则化损失
if hasattr(model, 'regularization_loss'):
loss = loss + model.regularization_loss() * 0.01
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
epoch_loss += loss.item()
# 计算训练损失
train_loss = epoch_loss / len(train_loader)
train_losses.append(train_loss)
# 在测试集上评估
model.eval()
test_loss = 0
with torch.no_grad():
for X_batch, y_batch in test_loader:
X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
# 确保目标张量有正确的形状
if y_batch.dim() == 2:
y_batch = y_batch.unsqueeze(-1)
outputs = model(X_batch)
# 确保输出形状与目标匹配
if outputs.dim() == 2:
outputs = outputs.unsqueeze(-1)
loss = criterion(outputs, y_batch)
test_loss += loss.item()
test_loss = test_loss / len(test_loader)
test_losses.append(test_loss)
2025-07-17 17:54:53 +08:00
# 检查是否为最佳模型
model_type_name = 'optimized_kan' if use_optimized else 'kan'
if test_loss < best_loss:
best_loss = test_loss
print(f"🎉 新的最佳模型发现在 epoch {epoch+1},测试损失: {test_loss:.4f}")
# 为保存最佳模型准备数据
best_model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_type': model_type_name,
'use_optimized': use_optimized
},
'epoch': epoch + 1
}
# 使用模型管理器保存 'best' 版本
from utils.model_manager import model_manager
model_manager.save_model(
model_data=best_model_data,
product_id=model_identifier, # 修正:使用唯一的标识符
2025-07-17 17:54:53 +08:00
model_type=model_type_name,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name,
version='best' # 显式覆盖版本为'best'
2025-07-17 17:54:53 +08:00
)
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")
# 计算训练时间
training_time = time.time() - start_time
# 绘制损失曲线并保存到模型目录
model_name = 'optimized_kan' if use_optimized else 'kan'
loss_curve_path = plot_loss_curve(
train_losses,
test_losses,
product_name,
model_type,
model_dir=model_dir
)
print(f"损失曲线已保存到: {loss_curve_path}")
# 评估模型
model.eval()
with torch.no_grad():
test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()
# 处理输出形状
if len(test_pred.shape) == 3:
test_pred = test_pred.squeeze(-1)
# 反归一化预测结果和真实值
test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()
# 计算评估指标
metrics = evaluate_model(test_true_inv, test_pred_inv)
metrics['training_time'] = training_time
# 打印评估指标
print("\n模型评估指标:")
print(f"MSE: {metrics['mse']:.4f}")
print(f"RMSE: {metrics['rmse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")
print(f"R²: {metrics['r2']:.4f}")
print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {training_time:.2f}")
2025-07-02 11:05:23 +08:00
# 使用统一模型管理器保存模型
from utils.model_manager import model_manager
2025-07-02 11:05:23 +08:00
model_type_name = 'optimized_kan' if use_optimized else 'kan'
2025-07-02 11:05:23 +08:00
model_data = {
'model_state_dict': model.state_dict(),
'scaler_X': scaler_X,
'scaler_y': scaler_y,
'config': {
'input_dim': input_dim,
'output_dim': output_dim,
'hidden_size': hidden_size,
2025-07-16 18:50:16 +08:00
'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
2025-07-16 12:59:56 +08:00
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
2025-07-02 11:05:23 +08:00
'model_type': model_type_name,
'use_optimized': use_optimized
},
'metrics': metrics,
'loss_history': {
'train': train_losses,
'test': test_losses,
'epochs': list(range(1, epochs + 1))
},
'loss_curve_path': loss_curve_path
2025-07-02 11:05:23 +08:00
}
# 保存最终模型,让 model_manager 自动处理版本号
final_model_path, final_version = model_manager.save_model(
2025-07-02 11:05:23 +08:00
model_data=model_data,
product_id=model_identifier, # 修正:使用唯一的标识符
2025-07-02 11:05:23 +08:00
model_type=model_type_name,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
product_name=product_name
# 注意此处不传递version参数由管理器自动生成
2025-07-02 11:05:23 +08:00
)
print(f"最终模型已保存,版本: {final_version}, 路径: {final_model_path}")
2025-07-22 15:40:37 +08:00
return model, metrics
# --- 将此训练器注册到系统中 ---
from models.model_registry import register_trainer
register_trainer('kan', train_product_model_with_kan)
register_trainer('optimized_kan', train_product_model_with_kan)