""" 药店销售预测系统 - 模型预测函数 """ import os import torch import pandas as pd import numpy as np from datetime import datetime, timedelta import matplotlib.pyplot as plt from sklearn.preprocessing import MinMaxScaler import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化 from typing import Optional from models.transformer_model import TimeSeriesTransformer from models.slstm_model import sLSTM as ScalarLSTM from models.mlstm_model import MLSTMTransformer as MatrixLSTM from models.kan_model import KANForecaster from models.tcn_model import TCNForecaster from models.optimized_kan_forecaster import OptimizedKANForecaster from models.cnn_bilstm_attention import CnnBiLstmAttention import xgboost as xgb from analysis.trend_analysis import analyze_prediction_result from utils.visualization import plot_prediction_results from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH from models.model_registry import get_predictor, register_predictor def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days): """ 默认的PyTorch模型预测逻辑,支持自动回归。 """ config = checkpoint['config'] scaler_X = checkpoint['scaler_X'] scaler_y = checkpoint['scaler_y'] features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']) sequence_length = config['sequence_length'] if start_date: start_date_dt = pd.to_datetime(start_date) prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length) else: prediction_input_df = product_df.tail(sequence_length) start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1) if len(prediction_input_df) < sequence_length: raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。") history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days) all_predictions = [] current_sequence_df = prediction_input_df.copy() for _ in range(future_days): X_current_scaled = scaler_X.transform(current_sequence_df[features].values) # **核心改进**: 智能判断模型类型并调用相应的预测方法 if isinstance(model, xgb.Booster): # XGBoost 模型预测路径 X_input_reshaped = X_current_scaled.reshape(1, -1) d_input = xgb.DMatrix(X_input_reshaped) # **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略 y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration)) next_step_pred_scaled = y_pred_scaled.reshape(1, -1) else: # 默认 PyTorch 模型预测路径 X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE) with torch.no_grad(): y_pred_scaled = model(X_input).cpu().numpy() next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1) next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0])) next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1) all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled}) new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]} new_row_df = pd.DataFrame([new_row]) current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True) predictions_df = pd.DataFrame(all_predictions) return predictions_df, history_for_chart_df, prediction_input_df # 注册默认的PyTorch预测器 register_predictor('default', default_pytorch_predictor) # 将增强后的默认预测器也注册给xgboost register_predictor('xgboost', default_pytorch_predictor) # 将新模型也注册给默认预测器 register_predictor('cnn_bilstm_attention', default_pytorch_predictor) def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30): """ 加载已训练的模型并进行预测 (v4版 - 插件式架构) """ try: if not os.path.exists(model_path): raise FileNotFoundError(f"模型文件 {model_path} 不存在") # --- 数据加载部分保持不变 --- from utils.multi_store_data_utils import aggregate_multi_store_data if training_mode == 'store' and store_id: from utils.multi_store_data_utils import load_multi_store_data store_df_for_name = load_multi_store_data(store_id=store_id) product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}" product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH) elif training_mode == 'global': product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH) product_name = "全局销售数据" else: product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH) product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id if product_df.empty: raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") # --- 模型加载与实例化 (重构) --- try: torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler]) except Exception: pass checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False) config = checkpoint.get('config', {}) loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型 # 根据模型类型决定如何获取模型实例 if loaded_model_type == 'xgboost': # 对于XGBoost, 模型对象直接保存在'model_state_dict'键中 model = checkpoint['model_state_dict'] else: # 对于PyTorch模型, 需要重新构建实例并加载state_dict if loaded_model_type == 'transformer': model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE) elif loaded_model_type == 'mlstm': model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE) elif loaded_model_type == 'kan': model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) elif loaded_model_type == 'optimized_kan': model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE) elif loaded_model_type == 'tcn': model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE) elif loaded_model_type == 'cnn_bilstm_attention': model = CnnBiLstmAttention( input_dim=config['input_dim'], output_dim=config['output_dim'], sequence_length=config['sequence_length'] ).to(DEVICE) else: raise ValueError(f"不支持的模型类型: {loaded_model_type}") model.load_state_dict(checkpoint['model_state_dict']) model.eval() # --- 动态调用预测器 --- predictor_function = get_predictor(loaded_model_type) if not predictor_function: raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现") predictions_df, history_for_chart_df, prediction_input_df = predictor_function( model=model, checkpoint=checkpoint, product_df=product_df, future_days=future_days, start_date=start_date, history_lookback_days=history_lookback_days ) # --- 分析与返回部分保持不变 --- analysis = None if analyze_result: try: analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values) except Exception as e: print(f"分析预测结果失败: {str(e)}") history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else [] prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else [] return { 'product_id': product_id, 'product_name': product_name, 'model_type': loaded_model_type, 'predictions': prediction_data_json, 'prediction_data': prediction_data_json, 'history_data': history_data_json, 'analysis': analysis } except Exception as e: print(f"预测过程中出现未捕获的异常: {str(e)}") import traceback traceback.print_exc() return None