""" 药店销售预测系统 - KAN模型训练函数 """ import os import time import pandas as pd import numpy as np import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from sklearn.preprocessing import MinMaxScaler import matplotlib.pyplot as plt from tqdm import tqdm from models.kan_model import KANForecaster from models.optimized_kan_forecaster import OptimizedKANForecaster from utils.data_utils import create_dataset, PharmacyDataset from utils.visualization import plot_loss_curve from analysis.metrics import evaluate_model from core.config import DEVICE, DEFAULT_MODEL_DIR, LOOK_BACK, FORECAST_HORIZON def train_product_model_with_kan(product_id, model_identifier, product_df=None, store_id=None, training_mode='product', aggregation_method='sum', epochs=50, sequence_length=LOOK_BACK, forecast_horizon=FORECAST_HORIZON, use_optimized=False, model_dir=DEFAULT_MODEL_DIR): """ 使用KAN模型训练产品销售预测模型 参数: product_id: 产品ID epochs: 训练轮次 use_optimized: 是否使用优化版KAN model_dir: 模型保存目录,默认使用配置中的DEFAULT_MODEL_DIR 返回: model: 训练好的模型 metrics: 模型评估指标 """ # 如果没有传入product_df,则根据训练模式加载数据 if product_df is None: from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data try: if training_mode == 'store' and store_id: # 加载特定店铺的数据 product_df = get_store_product_sales_data( store_id, product_id, 'pharmacy_sales_multi_store.csv' ) training_scope = f"店铺 {store_id}" elif training_mode == 'global': # 聚合所有店铺的数据 product_df = aggregate_multi_store_data( product_id, aggregation_method=aggregation_method, file_path='pharmacy_sales_multi_store.csv' ) training_scope = f"全局聚合({aggregation_method})" else: # 默认:加载所有店铺的产品数据 product_df = load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id) training_scope = "所有店铺" except Exception as e: print(f"多店铺数据加载失败: {e}") # 后备方案:尝试原始数据 df = pd.read_excel('pharmacy_sales.xlsx') product_df = df[df['product_id'] == product_id].sort_values('date') training_scope = "原始数据" else: # 如果传入了product_df,直接使用 if training_mode == 'store' and store_id: training_scope = f"店铺 {store_id}" elif training_mode == 'global': training_scope = f"全局聚合({aggregation_method})" else: training_scope = "所有店铺" if product_df.empty: raise ValueError(f"产品 {product_id} 没有可用的销售数据") # 数据量检查 min_required_samples = sequence_length + forecast_horizon if len(product_df) < min_required_samples: error_msg = ( f"❌ 训练数据不足错误\n" f"当前配置需要: {min_required_samples} 天数据 (LOOK_BACK={sequence_length} + FORECAST_HORIZON={forecast_horizon})\n" f"实际数据量: {len(product_df)} 天\n" f"产品ID: {product_id}, 训练模式: {training_mode}\n" f"建议解决方案:\n" f"1. 生成更多数据: uv run generate_multi_store_data.py\n" f"2. 调整配置参数: 减小 LOOK_BACK 或 FORECAST_HORIZON\n" f"3. 
        )
        print(error_msg)
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]

    model_type = "optimized KAN" if use_optimized else "KAN"
    print(f"Training sales forecasting model for product '{product_name}' (ID: {product_id}) with {model_type}")
    print(f"Training scope: {training_scope}")
    print(f"Device: {DEVICE}")
    print(f"Models will be saved to: {model_dir}")

    # Build feature and target variables
    features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']

    # Preprocess the data
    X = product_df[features].values
    y = product_df[['sales']].values  # keep as a 2-D array

    # Normalize the data
    scaler_X = MinMaxScaler(feature_range=(0, 1))
    scaler_y = MinMaxScaler(feature_range=(0, 1))

    X_scaled = scaler_X.fit_transform(X)
    y_scaled = scaler_y.fit_transform(y)

    # Split into training and test sets (80% train, 20% test)
    train_size = int(len(X_scaled) * 0.8)
    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    # Create time-series samples
    trainX, trainY = create_dataset(X_train, y_train, sequence_length, forecast_horizon)
    testX, testY = create_dataset(X_test, y_test, sequence_length, forecast_horizon)

    # Convert to PyTorch tensors
    trainX_tensor = torch.Tensor(trainX)
    trainY_tensor = torch.Tensor(trainY)
    testX_tensor = torch.Tensor(testX)
    testY_tensor = torch.Tensor(testY)

    # Create data loaders
    train_dataset = PharmacyDataset(trainX_tensor, trainY_tensor)
    test_dataset = PharmacyDataset(testX_tensor, testY_tensor)

    batch_size = 32
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Initialize the KAN model
    input_dim = X_train.shape[1]
    output_dim = forecast_horizon
    hidden_size = 64

    if use_optimized:
        model = OptimizedKANForecaster(
            input_features=input_dim,
            hidden_sizes=[hidden_size, hidden_size * 2, hidden_size],
            output_sequence_length=output_dim
        )
    else:
        model = KANForecaster(
            input_features=input_dim,
            hidden_sizes=[hidden_size, hidden_size * 2, hidden_size],
            output_sequence_length=output_dim
        )

    # Move the model to the target device
    model = model.to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train the model
    from utils.model_manager import model_manager
    model_type_name = 'optimized_kan' if use_optimized else 'kan'

    current_version = model_manager.peek_next_version(
        model_type=model_type_name,
        product_id=model_identifier,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method
    )
    print(f"🔒 Version locked for this training run: {current_version}")

    train_losses = []
    test_losses = []
    start_time = time.time()

    best_loss = float('inf')
    best_model_path = None

    for epoch in range(epochs):
        model.train()
        epoch_loss = 0

        for X_batch, y_batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False):
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

            # Ensure the target tensor has shape (batch_size, forecast_horizon, 1)
            if y_batch.dim() == 2:
                y_batch = y_batch.unsqueeze(-1)

            # Forward pass
            outputs = model(X_batch)

            # Ensure the output shape matches the target
            if outputs.dim() == 2:
                outputs = outputs.unsqueeze(-1)

            loss = criterion(outputs, y_batch)

            # For KAN models, add the regularization loss
            if hasattr(model, 'regularization_loss'):
                loss = loss + model.regularization_loss() * 0.01

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()

        # Compute the average training loss
        train_loss = epoch_loss / len(train_loader)
        train_losses.append(train_loss)

        # Evaluate on the test set
        model.eval()
        test_loss = 0
        with torch.no_grad():
            for X_batch, y_batch in test_loader:
                X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)

                # Ensure the target tensor has the correct shape
                if y_batch.dim() == 2:
                    y_batch = y_batch.unsqueeze(-1)

                outputs = model(X_batch)

                # Ensure the output shape matches the target
                if outputs.dim() == 2:
                    outputs = outputs.unsqueeze(-1)

                loss = criterion(outputs, y_batch)
                test_loss += loss.item()

        test_loss = test_loss / len(test_loader)
        test_losses.append(test_loss)

        # Check whether this is the best model so far
        model_type_name = 'optimized_kan' if use_optimized else 'kan'
        if test_loss < best_loss:
            best_loss = test_loss
            print(f"🎉 New best model found at epoch {epoch+1}, test loss: {test_loss:.4f}")

            # Prepare the data needed to save the best model
            best_model_data = {
                'model_state_dict': model.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
                'config': {
                    'input_dim': input_dim,
                    'output_dim': output_dim,
                    'hidden_size': hidden_size,
                    'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
                    'sequence_length': sequence_length,
                    'forecast_horizon': forecast_horizon,
                    'model_type': model_type_name,
                    'use_optimized': use_optimized
                },
                'epoch': epoch + 1
            }

            # Save the 'best' version via the model manager
            from utils.model_manager import model_manager
            best_model_path, _ = model_manager.save_model(
                model_data=best_model_data,
                product_id=model_identifier,  # fix: use the unique identifier
                model_type=model_type_name,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
                product_name=product_name,
                version=f"{current_version}_best"
            )

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}")

    # Compute total training time
    training_time = time.time() - start_time

    # Evaluate the final model
    model.eval()
    with torch.no_grad():
        test_pred = model(testX_tensor.to(DEVICE)).cpu().numpy()

        # Handle the output shape
        if len(test_pred.shape) == 3:
            test_pred = test_pred.squeeze(-1)

        # Inverse-transform predictions and ground truth
        test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, 1)).flatten()
        test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, 1)).flatten()

    # Compute evaluation metrics
    metrics = evaluate_model(test_true_inv, test_pred_inv)
    metrics['training_time'] = training_time

    # Print evaluation metrics
    print("\nModel evaluation metrics:")
    print(f"MSE: {metrics['mse']:.4f}")
    print(f"RMSE: {metrics['rmse']:.4f}")
    print(f"MAE: {metrics['mae']:.4f}")
    print(f"R²: {metrics['r2']:.4f}")
    print(f"MAPE: {metrics['mape']:.2f}%")
    print(f"Training time: {training_time:.2f}s")

    # --- 5. Save artifacts ---
    model_type_name = 'optimized_kan' if use_optimized else 'kan'

    # Prepare scope and identifier to build standardized file names
    scope = training_mode
    if scope == 'product':
        identifier = model_identifier
    elif scope == 'store':
        identifier = store_id
    elif scope == 'global':
        identifier = aggregation_method
    else:
        identifier = product_name  # fallback

    # Plot the versioned loss curve and save it to the model directory
    loss_curve_path = plot_loss_curve(
        train_losses=train_losses,
        val_losses=test_losses,
        model_type=model_type_name,
        scope=scope,
        identifier=identifier,
        version=current_version,  # use the locked version
        model_dir=model_dir
    )
    print(f"📈 Versioned loss curve saved to: {loss_curve_path}")

    # Prepare the final model data to save
    model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
        'config': {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_size': hidden_size,
            'hidden_sizes': [hidden_size, hidden_size * 2, hidden_size],
            'sequence_length': sequence_length,
            'forecast_horizon': forecast_horizon,
            'model_type': model_type_name,
            'use_optimized': use_optimized
        },
        'metrics': metrics,
        'loss_history': {
            'train': train_losses,
            'test': test_losses,
            'epochs': list(range(1, epochs + 1))
        },
        'loss_curve_path': loss_curve_path  # include the path directly
    }

    # Save the final model via the model manager
    from utils.model_manager import model_manager
    final_model_path, final_version = model_manager.save_model(
        model_data=model_data,
        product_id=model_identifier,
        model_type=model_type_name,
        store_id=store_id,
        training_mode=training_mode,
        aggregation_method=aggregation_method,
        product_name=product_name,
        version=current_version  # use the locked version
    )
    print(f"✅ Final {model_type_name} model saved, version: {final_version}")

    # Assemble the returned artifacts
    artifacts = {
        "versioned_model": final_model_path,
        "loss_curve_plot": loss_curve_path,
        "best_model": best_model_path,
        "version": final_version
    }

    return model, metrics, artifacts


# --- Register this trainer with the system ---
from models.model_registry import register_trainer
register_trainer('kan', train_product_model_with_kan)
register_trainer('optimized_kan', train_product_model_with_kan)
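

# --- Illustrative usage sketch ---
# A minimal example of how this trainer could be called directly, kept in a
# __main__ guard so it does not run on import. The product ID 'P001' and the
# small epoch count are hypothetical placeholders, not values taken from the
# project; in the actual system the trainer is normally dispatched through
# the model_registry rather than invoked by hand.
if __name__ == '__main__':
    trained_model, eval_metrics, saved_artifacts = train_product_model_with_kan(
        product_id='P001',        # hypothetical product ID
        model_identifier='P001',  # identifier used in saved-model file names
        training_mode='product',  # train on the product's data from all stores
        epochs=5,                 # short run for demonstration purposes
        use_optimized=True,       # use the OptimizedKANForecaster variant
    )
    print(f"Saved model: {saved_artifacts['versioned_model']}")
    print(f"RMSE: {eval_metrics['rmse']:.4f}")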