"""
Pharmacy sales forecasting system: KAN model training function
"""

import os
import time

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm

from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from utils.visualization import plot_loss_curve
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON


def train_product_model_with_kan(product_id, epochs=50, use_optimized=False, model_dir=DEFAULT_MODEL_DIR):
    """
    Train a sales forecasting model for one product using a KAN model.

    Args:
        product_id: Product ID.
        epochs: Number of training epochs.
        use_optimized: Whether to use the optimized KAN variant.
        model_dir: Directory in which to save the model; defaults to
            DEFAULT_MODEL_DIR from the configuration.

    Returns:
        model: The trained model.
        metrics: Evaluation metrics for the model.
    """
    # Load the generated pharmacy sales data
    df = pd.read_excel('pharmacy_sales.xlsx')

    # Select the rows for the requested product, in chronological order
    product_df = df[df['product_id'] == product_id].sort_values('date')
    if product_df.empty:
        raise ValueError(f"No sales data found for product ID {product_id}")
    product_name = product_df['product_name'].iloc[0]

    model_type = "Optimized KAN" if use_optimized else "KAN"
    print(f"Training a {model_type} sales forecasting model for product '{product_name}' (ID: {product_id})")
    print(f"Using device: {DEVICE}")
    print(f"The model will be saved to: {model_dir}")

    # Create the output directory up front, because the loss curve is
    # written there before the model checkpoint is saved
    os.makedirs(model_dir, exist_ok=True)

    # Feature columns and target variable
    features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    X = product_df[features].values
    y = product_df[['sales']].values  # keep as a 2D array

    # Chronological split: 80% training, 20% test
    train_size = int(len(X) * 0.8)

    # Fit the scalers on the training portion only, so that test-set
    # statistics do not leak into training
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))
    X_train = scaler_X.fit_transform(X[:train_size])
    X_test = scaler_X.transform(X[train_size:])
    y_train = scaler_y.fit_transform(y[:train_size])
    y_test = scaler_y.transform(y[train_size:])

    # Build the time-series windows
    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)
    if len(trainX) == 0 or len(testX) == 0:
        raise ValueError(
            f"Not enough data for product ID {product_id} to build windows with "
            f"LOOK_BACK={LOOK_BACK} and FORECAST_HORIZON={FORECAST_HORIZON}"
        )

    # Convert to float32 PyTorch tensors
    trainX_tensor = torch.as_tensor(trainX, dtype=torch.float32)
    trainY_tensor = torch.as_tensor(trainY, dtype=torch.float32)
    testX_tensor = torch.as_tensor(testX, dtype=torch.float32)
    testY_tensor = torch.as_tensor(testY, dtype=torch.float32)

    # Create the data loaders
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize the KAN model
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    hidden_size = 64
    hidden_sizes = [hidden_size, hidden_size * 2, hidden_size]

    # Both variants take the same constructor arguments
    model_cls = OptimizedKANForecaster if use_optimized else KANForecaster
    model = model_cls(
        input_features=input_dim,
        hidden_sizes=hidden_sizes,
        output_sequence_length=output_dim
    )

    # Move the model to the target device
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    train_losses = []
    test_losses = []
    start_time = time.time()

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # Ensure the target has shape (batch_size, forecast_horizon, 1)
            if y_batch.dim() == 2:
                y_batch = y_batch.unsqueeze(-1)

            # Forward pass
            outputs = model(X_batch)

            # Ensure the output shape matches the target
            if outputs.dim() == 2:
                outputs = outputs.unsqueeze(-1)

            loss = criterion(outputs, y_batch)

            # If the model exposes a regularization term, add it to the loss
            if hasattr(model, 'regularization_loss'):
                loss = loss + model.regularization_loss() * 0.01

            # Backward pass and optimization step
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Average training loss for this epoch (includes the regularization term)
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # Ensure the target has the expected shape
                if y_batch.dim() == 2:
                    y_batch = y_batch.unsqueeze(-1)

                outputs = model(X_batch)

                # Ensure the output shape matches the target
                if outputs.dim() == 2:
                    outputs = outputs.unsqueeze(-1)

                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    # Total training time
    training_time = time.time() - start_time

    # Plot the loss curve and save it in the model directory
    model_name = 'optimized_kan' if use_optimized else 'kan'
    loss_curve_path = plot_loss_curve(
        train_losses,
        test_losses,
        product_name,
        model_type,
        model_dir=model_dir
    )
    print(f"Loss curve saved to: {loss_curve_path}")

    # Evaluate the model on the full test set
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()

    # Drop a trailing singleton dimension if present
    if len(test_pred.shape) == 3:
        test_pred = test_pred.squeeze(-1)

    # Undo the normalization for predictions and ground truth
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    # Compute the evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # Print the evaluation metrics
    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f}s")

    # Build the checkpoint file name and save the model
    model_file_prefix = 'kan_optimized' if use_optimized else 'kan'
    model_path = os.path.join(model_dir, f"{model_file_prefix}_model_product_{product_id}.pth")
    torch.save({
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'hidden_sizes': hidden_sizes,
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'model_type': model_name,
            'use_optimized': use_optimized
        },
        'metrics': metrics,
        'loss_history': {
            'train': train_losses,
            'test': test_losses,
            'epochs': list(range(1, epochs + 1))
        },
        'loss_curve_path': loss_curve_path
    }, model_path)

    print(f"Model saved to {model_path}")

    return model, metrics
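

# Illustrative sketch only. `create_dataset` is imported from utils.data_utils
# and its implementation is not shown in this file, so the helper below is an
# ASSUMPTION about the windowing it performs, not the project's actual code.
# The name `sliding_window_sketch` is hypothetical. It pairs each run of
# `look_back` consecutive feature rows with the `horizon` target rows that
# immediately follow, which is the shape the training loop above appears to
# expect. It works on lists and NumPy arrays alike; wrap the results in
# np.asarray(...) to get arrays.
def sliding_window_sketch(features, targets, look_back, horizon):
    """Return (windows, labels) built with a step-1 sliding window."""
    windows, labels = [], []
    # Last start index that still leaves room for a full window plus horizon
    last_start = len(features) - look_back - horizon
    for start in range(last_start + 1):
        split = start + look_back
        windows.append(features[start:split])
        labels.append(targets[split:split + horizon])
    return windows, labels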