""" 药店销售预测系统 - 模型预测函数 """ import os import torch import pandas as pd import numpy as np from datetime import datetime, timedelta import matplotlib.pyplot as plt from sklearn.preprocessing import MinMaxScaler import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化 from models.transformer_model import TimeSeriesTransformer from models.slstm_model import sLSTM as ScalarLSTM from models.mlstm_model import MLSTMTransformer as MatrixLSTM from models.kan_model import KANForecaster from models.tcn_model import TCNForecaster from models.optimized_kan_forecaster import OptimizedKANForecaster from analysis.trend_analysis import analyze_prediction_result from utils.visualization import plot_prediction_results from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product'): """ 加载已训练的模型并进行预测 参数: product_id: 产品ID model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan') store_id: 店铺ID,为None时使用全局模型 future_days: 预测未来天数 start_date: 预测起始日期,如果为None则使用最后一个已知日期 analyze_result: 是否分析预测结果 version: 模型版本,如果为None则使用最新版本 返回: 预测结果和分析(如果analyze_result为True) """ try: # 确定模型文件路径(支持多店铺) model_path = None if version: # 使用版本管理系统获取正确的文件路径 model_path = get_model_file_path(product_id, model_type, version) else: # 根据store_id确定搜索目录 if store_id: # 查找特定店铺的模型 possible_dirs = [ os.path.join('saved_models', model_type, store_id), os.path.join('models', model_type, store_id) ] else: # 查找全局模型 possible_dirs = [ os.path.join('saved_models', model_type, 'global'), os.path.join('models', model_type, 'global'), os.path.join('saved_models', model_type), # 后向兼容 'saved_models' # 最基本的目录 ] # 文件名模式 model_suffix = '_optimized' if model_type == 'optimized_kan' else '' file_model_type = 'kan' if model_type == 'optimized_kan' else model_type possible_names = [ f"{product_id}_{model_type}_v1_model.pt", # 新多店铺格式 f"{product_id}_{model_type}_v1_global_model.pt", # 全局模型格式 f"{product_id}_{model_type}_v1.pth", # 旧版本格式 f"{file_model_type}{model_suffix}_model_product_{product_id}.pth", # 原始格式 f"{model_type}_model_product_{product_id}.pth" # 简化格式 ] # 搜索模型文件 for dir_path in possible_dirs: if not os.path.exists(dir_path): continue for name in possible_names: test_path = os.path.join(dir_path, name) if os.path.exists(test_path): model_path = test_path break if model_path: break if not model_path: scope_msg = f"店铺 {store_id}" if store_id else "全局" print(f"找不到产品 {product_id} 的 {model_type} 模型文件 ({scope_msg})") print(f"搜索目录: {possible_dirs}") return None print(f"尝试加载模型文件: {model_path}") if not os.path.exists(model_path): print(f"模型文件 {model_path} 不存在") return None # 加载销售数据(支持多店铺) try: from utils.multi_store_data_utils import aggregate_multi_store_data # 根据训练模式加载相应的数据 if training_mode == 'store' and store_id: # 店铺模型:聚合该店铺的所有产品数据 product_df = aggregate_multi_store_data( store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH ) store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns and not product_df.empty else f"店铺{store_id}" prediction_scope = f"店铺 '{store_name}' ({store_id})" product_name = store_name elif training_mode == 'global': # 全局模型:聚合所有数据 product_df = aggregate_multi_store_data( aggregation_method='sum', file_path=DEFAULT_DATA_PATH ) prediction_scope = "全局聚合数据" product_name = "全局销售数据" else: # 产品模型(默认):聚合该产品在所有店铺的数据 product_df = 
aggregate_multi_store_data( product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH ) prediction_scope = "全部店铺(聚合数据)" product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id except Exception as e: print(f"加载数据失败: {e}") return None if product_df.empty: print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据") return None print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量") print(f"预测范围: {prediction_scope}") # 添加安全的全局变量以支持MinMaxScaler的反序列化 try: torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler]) except Exception as e: print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}") # 加载模型和配置 try: # 首先尝试使用weights_only=False加载 try: print("尝试使用 weights_only=False 加载模型") checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False) except Exception as e: print(f"使用weights_only=False加载失败: {str(e)}") print("尝试使用默认参数加载模型") checkpoint = torch.load(model_path, map_location=DEVICE) print(f"模型加载成功,检查checkpoint类型: {type(checkpoint)}") if isinstance(checkpoint, dict): print(f"checkpoint包含的键: {list(checkpoint.keys())}") else: print(f"checkpoint不是字典类型,而是: {type(checkpoint)}") return None except Exception as e: print(f"加载模型失败: {str(e)}") return None # 检查并获取配置 if 'config' not in checkpoint: print("模型文件中没有配置信息") return None config = checkpoint['config'] print(f"模型配置: {config}") # 检查并获取缩放器 if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint: print("模型文件中没有缩放器信息") return None scaler_X = checkpoint['scaler_X'] scaler_y = checkpoint['scaler_y'] # 创建模型实例 try: if model_type == 'transformer': model = TimeSeriesTransformer( num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32 ).to(DEVICE) elif model_type == 'slstm': model = ScalarLSTM( input_dim=config['input_dim'], hidden_dim=config['hidden_size'], output_dim=config['output_dim'], num_layers=config['num_layers'], dropout=config['dropout'] ).to(DEVICE) elif model_type == 'mlstm': # 获取配置参数,如果不存在则使用默认值 embed_dim = config.get('embed_dim', 32) dense_dim = config.get('dense_dim', 32) num_heads = config.get('num_heads', 4) num_blocks = config.get('num_blocks', 3) model = MatrixLSTM( num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=embed_dim, dense_dim=dense_dim, num_heads=num_heads, dropout_rate=config['dropout_rate'], num_blocks=num_blocks, output_sequence_length=config['output_dim'] ).to(DEVICE) elif model_type == 'kan': model = KANForecaster( input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim'] ).to(DEVICE) elif model_type == 'optimized_kan': model = OptimizedKANForecaster( input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim'] ).to(DEVICE) elif model_type == 'tcn': model = TCNForecaster( num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout'] ).to(DEVICE) else: print(f"不支持的模型类型: {model_type}") return None print(f"模型实例创建成功: {type(model)}") except Exception as e: print(f"创建模型实例失败: 
{str(e)}") return None # 加载模型参数 try: model.load_state_dict(checkpoint['model_state_dict']) model.eval() print("模型参数加载成功") except Exception as e: print(f"加载模型参数失败: {str(e)}") return None # 准备输入数据 try: features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] sequence_length = config['sequence_length'] # 获取最近的sequence_length天数据作为输入 recent_data = product_df.iloc[-sequence_length:].copy() # 如果指定了起始日期,则使用该日期之后的数据 if start_date: if isinstance(start_date, str): start_date = datetime.strptime(start_date, '%Y-%m-%d') recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy() if len(recent_data) < sequence_length: print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length} 天") # 补充数据 missing_days = sequence_length - len(recent_data) additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy() recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True) print(f"输入数据准备完成,形状: {recent_data.shape}") except Exception as e: print(f"准备输入数据失败: {str(e)}") return None # 归一化输入数据 try: X = recent_data[features].values X_scaled = scaler_X.transform(X) # 转换为模型输入格式 X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE) print(f"输入张量准备完成,形状: {X_input.shape}") except Exception as e: print(f"归一化输入数据失败: {str(e)}") return None # 预测 try: with torch.no_grad(): y_pred_scaled = model(X_input).cpu().numpy() print(f"原始预测输出形状: {y_pred_scaled.shape}") # 处理TCN、Transformer、mLSTM和KAN模型的输出,确保形状正确 if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3: y_pred_scaled = y_pred_scaled.squeeze(-1) print(f"处理后的预测输出形状: {y_pred_scaled.shape}") # 反归一化预测结果 y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten() print(f"反归一化后的预测结果: {y_pred}") # 生成预测日期 last_date = recent_data['date'].iloc[-1] pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))] print(f"预测日期: {pred_dates}") except Exception as e: print(f"执行预测失败: {str(e)}") return None # 创建预测结果DataFrame try: predictions_df = pd.DataFrame({ 'date': pred_dates, 'sales': y_pred # 使用sales字段名而不是predicted_sales,以便与历史数据兼容 }) print(f"预测结果DataFrame创建成功,形状: {predictions_df.shape}") except Exception as e: print(f"创建预测结果DataFrame失败: {str(e)}") return None # 绘制预测结果 try: plt.figure(figsize=(12, 6)) plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量') plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量') plt.title(f'{product_name} - {model_type}模型销量预测') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() # 保存图像 plt.savefig(f'{product_id}_{model_type}_prediction.png') plt.close() print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png") except Exception as e: print(f"绘制预测结果图表失败: {str(e)}") # 这个错误不影响主要功能,继续执行 # 分析预测结果 analysis = None if analyze_result: try: analysis = analyze_prediction_result(product_id, model_type, y_pred, X) print("\n预测结果分析:") if analysis and 'explanation' in analysis: print(analysis['explanation']) else: print("分析结果不包含explanation字段") except Exception as e: print(f"分析预测结果失败: {str(e)}") # 分析失败不影响主要功能,继续执行 # 准备用于图表展示的历史数据 history_df = product_df if start_date: try: # 筛选出所有早于预测起始日期的数据 history_df = product_df[product_df['date'] < pd.to_datetime(start_date)] except Exception as e: print(f"筛选历史数据时日期格式错误: {e}") # 从正确的历史记录中取最后30天 recent_history = history_df.tail(30) return { 'product_id': product_id, 'product_name': product_name, 'model_type': model_type, 
'predictions': predictions_df, 'history_data': recent_history, # 将历史数据添加到返回结果中 'analysis': analysis } except Exception as e: print(f"预测过程中出现未捕获的异常: {str(e)}") import traceback traceback.print_exc() return None
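

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module). It assumes that a
# trained checkpoint exists under saved_models/ and that 'P001' is a valid
# product ID; both the product ID and the start date are placeholders, not
# real project data. Only keys documented by the function's return value are
# accessed here.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    result = load_model_and_predict(
        product_id='P001',        # hypothetical product ID
        model_type='tcn',         # one of: transformer, slstm, mlstm, kan, tcn, optimized_kan
        store_id=None,            # None -> use the global/aggregated model
        future_days=7,
        start_date='2024-01-01',  # placeholder date; None uses the last known date
        analyze_result=True,
        training_mode='product'
    )
    if result is not None:
        print(result['predictions'])          # DataFrame with 'date' and 'sales' columns
        print(result['history_data'].tail())  # tail of the 30-day history window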