"""Pharmacy sales forecasting system - KAN model training function (refactored)."""

import copy
import os
import time
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.preprocessing import MinMaxScaler
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

from models.kan_model import KANForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from utils.data_utils import create_dataset, PharmacyDataset
from analysis.metrics import evaluate_model
from core.config import DEVICE, LOOK_BACK, FORECAST_HORIZON
from utils.model_manager import model_manager

# Input feature columns, in the order they are fed to the model.
_FEATURE_COLUMNS = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# Fraction of rows (chronologically first) used for training; the rest is the test set.
_TRAIN_SPLIT_RATIO = 0.8
_BATCH_SIZE = 32
_HIDDEN_SIZE = 64
# Weight applied to the model's optional ``regularization_loss()`` term.
_REGULARIZATION_WEIGHT = 0.01


def _resolve_scope(training_mode, store_id, product_id):
    """Return the model-manager scope string for the given training mode."""
    if training_mode == 'store':
        return f"{store_id}_{product_id}"
    if training_mode == 'global':
        return f"{product_id}" if product_id else "all"
    # 'product' mode (and any other value) trains across all stores.
    return f"{product_id}_all"


def _describe_training_scope(training_mode, store_id, aggregation_method):
    """Return the human-readable training-scope label used in progress messages and metadata."""
    if training_mode == 'store' and store_id:
        return f"店铺 {store_id}"
    if training_mode == 'global':
        return f"全局聚合({aggregation_method})"
    return "所有店铺"


def _load_product_df(product_id, store_id, training_mode, aggregation_method, emit_progress):
    """Load the product's sales rows when the caller did not pass a DataFrame.

    Reads the multi-store CSV according to ``training_mode``; on any error it
    reports the failure and falls back to filtering 'pharmacy_sales.xlsx'.
    """
    # Imported lazily: only needed when no DataFrame is supplied by the caller.
    from utils.multi_store_data_utils import load_multi_store_data, get_store_product_sales_data, aggregate_multi_store_data
    try:
        if training_mode == 'store' and store_id:
            return get_store_product_sales_data(store_id, product_id, 'pharmacy_sales_multi_store.csv')
        if training_mode == 'global':
            return aggregate_multi_store_data(
                product_id,
                aggregation_method=aggregation_method,
                file_path='pharmacy_sales_multi_store.csv',
            )
        return load_multi_store_data('pharmacy_sales_multi_store.csv', product_id=product_id)
    except Exception as e:
        emit_progress(f"多店铺数据加载失败: {e}, 尝试后备方案...")
        df = pd.read_excel('pharmacy_sales.xlsx')
        return df[df['product_id'] == product_id].sort_values('date')


def _align_shapes(outputs, targets):
    """Add a trailing singleton axis to 2-D predictions/targets so the loss compares matching shapes."""
    if targets.dim() == 2:
        targets = targets.unsqueeze(-1)
    if outputs.dim() == 2:
        outputs = outputs.unsqueeze(-1)
    return outputs, targets


def _train_one_epoch(model, loader, criterion, optimizer):
    """Run one optimisation pass over ``loader``; return the mean batch loss (regularization included)."""
    model.train()
    total_loss = 0.0
    for X_batch, y_batch in loader:
        X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
        outputs, y_batch = _align_shapes(model(X_batch), y_batch)
        loss = criterion(outputs, y_batch)
        if hasattr(model, 'regularization_loss'):
            loss = loss + model.regularization_loss() * _REGULARIZATION_WEIGHT
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss / len(loader)


def _evaluate_loss(model, loader, criterion):
    """Return the mean batch loss of ``model`` over ``loader`` without updating weights."""
    model.eval()
    total_loss = 0.0
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(DEVICE), y_batch.to(DEVICE)
            outputs, y_batch = _align_shapes(model(X_batch), y_batch)
            total_loss += criterion(outputs, y_batch).item()
    return total_loss / len(loader)


def _save_loss_curve(train_losses, test_losses, title, model_version_path):
    """Plot train/test loss per epoch and store it as 'loss_curve.png' via the model manager."""
    loss_fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(train_losses, label='Training Loss')
        plt.plot(test_losses, label='Test Loss')
        plt.title(title)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True)
        model_manager.save_model_artifact(loss_fig, "loss_curve.png", model_version_path)
    finally:
        # Always release the figure, even if saving fails, to avoid leaking matplotlib state.
        plt.close(loss_fig)


def train_product_model_with_kan(
    product_id,
    product_df=None,
    store_id=None,
    training_mode='product',
    aggregation_method='sum',
    epochs=50,
    use_optimized=False,
    socketio=None,
    task_id=None,
    progress_callback=None,
    patience=10,
    learning_rate=0.001
):
    """Train a KAN sales-forecasting model for one product and persist it via ``model_manager``.

    Args:
        product_id: Product to train on.
        product_df: Pre-loaded sales rows. Must contain 'date', 'product_name'
            and every column in ``_FEATURE_COLUMNS``. When None, data is loaded
            from 'pharmacy_sales_multi_store.csv', falling back to 'pharmacy_sales.xlsx'.
        store_id: Store to train on in 'store' mode.
        training_mode: 'product', 'store' or 'global'. Selects the data source
            and the model scope.
        aggregation_method: Aggregation used in 'global' mode; also part of the
            model identifier and version path.
        epochs: Maximum number of training epochs.
        use_optimized: Use ``OptimizedKANForecaster`` instead of ``KANForecaster``.
        socketio: Optional object with ``emit(event, data, namespace=...)``.
            Progress is emitted as 'training_progress' on '/training' when
            ``task_id`` is also set.
        task_id: Identifier attached to every progress message.
        progress_callback: Optional callable that receives each progress dict.
        patience: Consecutive epochs without test-loss improvement before
            training stops early.
        learning_rate: Adam learning rate.

    Returns:
        Tuple ``(model, metrics, version, model_version_path)``.
        ``model`` holds the weights from the epoch with the lowest test loss.
        ``metrics`` is the dict from ``evaluate_model`` plus 'training_time'
        in seconds.

    Raises:
        ValueError: If there are too few rows to build both training and test windows.
    """
    def emit_progress(message, progress=None, metrics=None):
        """Send training progress to the front end (callback and/or WebSocket) and stdout."""
        progress_data = {
            'task_id': task_id,
            'message': f"[KAN] {message}",
            'timestamp': time.time()
        }
        if progress is not None:
            progress_data['progress'] = progress
        if metrics is not None:
            progress_data['metrics'] = metrics
        if progress_callback:
            try:
                progress_callback(progress_data)
            except Exception as e:
                print(f"[KAN] 进度回调失败: {e}")
        if socketio and task_id:
            try:
                socketio.emit('training_progress', progress_data, namespace='/training')
            except Exception as e:
                print(f"[KAN] WebSocket发送失败: {e}")
        print(f"[KAN] {message}", flush=True)

    emit_progress("开始KAN模型训练...")

    # 1. Model identifier and version.
    model_type = 'optimized_kan' if use_optimized else 'kan'
    scope = _resolve_scope(training_mode, store_id, product_id)
    model_identifier = model_manager.get_model_identifier(model_type, training_mode, scope, aggregation_method)
    version = model_manager.get_next_version_number(model_identifier)
    emit_progress(f"开始训练 {model_type} 模型 v{version}")

    # 2. Output directory for this version.
    model_version_path = model_manager.get_model_version_path(model_type, training_mode, scope, version, aggregation_method)
    emit_progress(f"模型将保存到: {model_version_path}")

    # 3. Data loading and preprocessing.
    if product_df is None:
        product_df = _load_product_df(product_id, store_id, training_mode, aggregation_method, emit_progress)

    training_scope = _describe_training_scope(training_mode, store_id, aggregation_method)

    min_required_samples = LOOK_BACK + FORECAST_HORIZON
    if len(product_df) < min_required_samples:
        error_msg = f"数据不足: 需要 {min_required_samples} 天, 实际 {len(product_df)} 天。"
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    product_df = product_df.sort_values('date')
    product_name = product_df['product_name'].iloc[0]

    emit_progress(f"训练产品: '{product_name}' (ID: {product_id}) - {training_scope}")
    emit_progress(f"使用设备: {DEVICE}, 数据量: {len(product_df)} 条")

    X = product_df[_FEATURE_COLUMNS].values
    y = product_df[['sales']].values

    train_size = int(len(X) * _TRAIN_SPLIT_RATIO)

    # Fit the scalers on the training rows only, so no statistics from the
    # held-out test period leak into training or into the reported metrics.
    scaler_X = MinMaxScaler(feature_range=(0, 1)).fit(X[:train_size])
    scaler_y = MinMaxScaler(feature_range=(0, 1)).fit(y[:train_size])
    X_scaled = scaler_X.transform(X)
    y_scaled = scaler_y.transform(y)

    X_train, X_test = X_scaled[:train_size], X_scaled[train_size:]
    y_train, y_test = y_scaled[:train_size], y_scaled[train_size:]

    trainX, trainY = create_dataset(X_train, y_train, LOOK_BACK, FORECAST_HORIZON)
    testX, testY = create_dataset(X_test, y_test, LOOK_BACK, FORECAST_HORIZON)

    # The total-length check above does not guarantee that *each* split yields
    # at least one window. An empty split would otherwise crash later with a
    # ZeroDivisionError when averaging losses over an empty DataLoader.
    if len(trainX) == 0 or len(testX) == 0:
        error_msg = (
            f"数据不足: 划分后训练集 ({len(X_train)} 天) 或测试集 ({len(X_test)} 天) "
            f"无法构成长度为 {min_required_samples} 天的时间窗口。"
        )
        emit_progress(f"训练失败:{error_msg}")
        raise ValueError(error_msg)

    train_loader = DataLoader(PharmacyDataset(torch.Tensor(trainX), torch.Tensor(trainY)), batch_size=_BATCH_SIZE, shuffle=True)
    test_loader = DataLoader(PharmacyDataset(torch.Tensor(testX), torch.Tensor(testY)), batch_size=_BATCH_SIZE, shuffle=False)

    # 4. Model initialisation.
    input_dim = X_train.shape[1]
    output_dim = FORECAST_HORIZON
    hidden_sizes = [_HIDDEN_SIZE, _HIDDEN_SIZE * 2, _HIDDEN_SIZE]
    model_cls = OptimizedKANForecaster if use_optimized else KANForecaster
    model = model_cls(
        input_features=input_dim,
        hidden_sizes=list(hidden_sizes),
        output_sequence_length=output_dim,
    ).to(DEVICE)

    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    emit_progress("数据预处理完成,开始模型训练...", progress=10)

    # 5. Training loop with early stopping on test loss.
    train_losses, test_losses = [], []
    start_time = time.time()
    best_loss = float('inf')
    best_state = None
    epochs_no_improve = 0

    for epoch in range(epochs):
        train_loss = _train_one_epoch(model, train_loader, criterion, optimizer)
        train_losses.append(train_loss)

        test_loss = _evaluate_loss(model, test_loader, criterion)
        test_losses.append(test_loss)

        # Training occupies the 10%..95% band of the overall progress bar.
        progress_percentage = 10 + ((epoch + 1) / epochs) * 85
        emit_progress(f"Epoch {epoch+1}/{epochs} - Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}", progress=progress_percentage)

        if test_loss < best_loss:
            best_loss = test_loss
            epochs_no_improve = 0
            # Keep an in-memory copy so the best weights can be restored after training.
            best_state = copy.deepcopy(model.state_dict())
            checkpoint_data = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scaler_X': scaler_X,
                'scaler_y': scaler_y,
            }
            model_manager.save_model_artifact(checkpoint_data, "checkpoint_best.pth", model_version_path)
            emit_progress(f"💾 保存最佳模型检查点 (epoch {epoch+1}, test_loss: {test_loss:.4f})")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                emit_progress(f"连续 {patience} 个epoch测试损失未改善,提前停止训练。")
                break

    training_time = time.time() - start_time

    # Early stopping leaves the model up to `patience` epochs past its best
    # point. Restore the best weights so the reported metrics, model.pth and
    # the returned model all match the checkpointed model.
    if best_state is not None:
        model.load_state_dict(best_state)

    # 6. Save artifacts and evaluate.
    _save_loss_curve(
        train_losses,
        test_losses,
        f'{model_type} 损失曲线 - {product_name} (v{version}) - {training_scope}',
        model_version_path,
    )
    emit_progress(f"损失曲线已保存到: {os.path.join(model_version_path, 'loss_curve.png')}")

    model.eval()
    with torch.no_grad():
        test_pred = model(torch.Tensor(testX).to(DEVICE)).cpu().numpy()
    if test_pred.ndim == 3:
        test_pred = test_pred.squeeze(-1)

    # Map predictions and targets back to the original sales scale before scoring.
    test_pred_inv = scaler_y.inverse_transform(test_pred.reshape(-1, FORECAST_HORIZON))
    test_true_inv = scaler_y.inverse_transform(testY.reshape(-1, FORECAST_HORIZON))

    metrics = evaluate_model(test_true_inv.flatten(), test_pred_inv.flatten())
    metrics['training_time'] = training_time

    emit_progress(f"模型评估完成 - RMSE: {metrics['rmse']:.4f}, R²: {metrics['r2']:.4f}")

    # 7. Save the final model and its metadata.
    final_model_data = {
        'model_state_dict': model.state_dict(),
        'scaler_X': scaler_X,
        'scaler_y': scaler_y,
    }
    model_manager.save_model_artifact(final_model_data, "model.pth", model_version_path)

    metadata = {
        'product_id': product_id,
        'product_name': product_name,
        'model_type': model_type,
        'version': f'v{version}',
        'training_mode': training_mode,
        'scope': scope,
        'aggregation_method': aggregation_method,
        'training_scope_description': training_scope,
        'timestamp': datetime.now().isoformat(),
        'metrics': metrics,
        'config': {
            'input_dim': input_dim,
            'output_dim': output_dim,
            'hidden_sizes': list(hidden_sizes),
            'sequence_length': LOOK_BACK,
            'forecast_horizon': FORECAST_HORIZON,
            'use_optimized': use_optimized
        }
    }
    model_manager.save_model_artifact(metadata, "metadata.json", model_version_path)

    # 8. Record the new version only after every artifact has been written.
    model_manager.update_version(model_identifier, version)

    emit_progress(f"✅ {model_type}模型 v{version} 训练完成!", progress=100, metrics=metrics)

    return model, metrics, version, model_version_path