191 lines
10 KiB
Python
191 lines
10 KiB
Python
"""
|
||
药店销售预测系统 - 模型预测函数
|
||
"""
|
||
|
||
import os
|
||
import torch
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
import matplotlib.pyplot as plt
|
||
from sklearn.preprocessing import MinMaxScaler
|
||
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
|
||
from typing import Optional
|
||
|
||
from models.transformer_model import TimeSeriesTransformer
|
||
from models.slstm_model import sLSTM as ScalarLSTM
|
||
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
|
||
from models.kan_model import KANForecaster
|
||
from models.tcn_model import TCNForecaster
|
||
from models.optimized_kan_forecaster import OptimizedKANForecaster
|
||
from models.cnn_bilstm_attention import CnnBiLstmAttention
|
||
import xgboost as xgb
|
||
|
||
from analysis.trend_analysis import analyze_prediction_result
|
||
from utils.visualization import plot_prediction_results
|
||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||
from models.model_registry import get_predictor, register_predictor
|
||
|
||
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """Autoregressively forecast ``future_days`` of sales with a loaded model.

    Handles PyTorch models and, via an ``isinstance`` check, ``xgb.Booster``
    models.

    Each step:

    * scales the current ``sequence_length``-row window with ``scaler_X``;
    * runs the model;
    * inverse-scales the first predicted value with ``scaler_y``, clipping it
      at 0;
    * appends that value as a synthetic row, so the next step sees it as
      history.

    Args:
        model: A PyTorch module (already on ``DEVICE``) or an ``xgb.Booster``.
        checkpoint: Dict with ``'config'`` (must contain ``'sequence_length'``,
            may contain ``'features'``), ``'scaler_X'`` and ``'scaler_y'``.
        product_df: History with a ``'date'`` column plus every feature column.
        future_days: Number of steps (days) to forecast.
        start_date: Optional date understood by ``pd.to_datetime``. Only rows
            strictly before it are used as model input. When falsy, the
            forecast follows the last row of ``product_df``.
        history_lookback_days: Number of history rows before the start date
            returned for charting.

    Returns:
        Tuple ``(predictions_df, history_for_chart_df, prediction_input_df)``.
        ``predictions_df`` always has the columns ``'date'`` and
        ``'predicted_sales'``, even when no step is forecast.

    Raises:
        ValueError: If fewer than ``sequence_length`` history rows are available.
    """
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        start_date_dt = None
        prediction_input_df = product_df.tail(sequence_length)

    # Check before touching iloc[-1], so an empty frame yields the intended
    # ValueError instead of an IndexError.
    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    if start_date_dt is None:
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    is_xgboost = isinstance(model, xgb.Booster)
    if is_xgboost:
        # ``best_iteration`` is the 0-based index of the best boosting round,
        # and ``iteration_range`` is half-open, so the range must end at
        # best + 1. (The old ``(0, best)`` dropped the best tree; with best == 0
        # it became ``(0, 0)``, which XGBoost treats as "all trees".)
        # Boosters without ``best_iteration`` fall back to ``(0, 0)`` = every tree.
        try:
            xgb_iteration_range = (0, int(model.best_iteration) + 1)
        except AttributeError:
            xgb_iteration_range = (0, 0)

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()

    for _ in range(future_days):
        X_current_scaled = scaler_X.transform(current_sequence_df[features].values)

        if is_xgboost:
            # The booster consumes the whole window flattened into one row.
            d_input = xgb.DMatrix(X_current_scaled.reshape(1, -1))
            y_pred_scaled = model.predict(d_input, iteration_range=xgb_iteration_range)
            next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
        else:
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            # Only the first step of the model's output horizon is used per iteration.
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)

        # Sales cannot be negative, so clip the inverse-scaled value at 0.
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        # Start from the previous row so that every column not recomputed below
        # (e.g. temperature, or any extra configured feature) is carried
        # forward instead of turning into NaN after the concat.
        new_row = current_sequence_df.iloc[-1].to_dict()
        new_row.update({
            'date': next_date,
            'sales': next_step_pred_unscaled,
            'weekday': next_date.weekday(),
            'month': next_date.month,
            'is_holiday': 0,
            'is_weekend': 1 if next_date.weekday() >= 5 else 0,
            'is_promotion': 0,
        })
        new_row_df = pd.DataFrame([new_row])
        # Slide the window: drop the oldest row, append the synthetic one.
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    predictions_df = pd.DataFrame(all_predictions, columns=['date', 'predicted_sales'])
    return predictions_df, history_for_chart_df, prediction_input_df
|
||
|
||
# Register the shared autoregressive predictor (which handles both the PyTorch
# and the XGBoost path) under each of these model-type keys, in this order.
for _predictor_key in ('default', 'xgboost', 'cnn_bilstm_attention'):
    register_predictor(_predictor_key, default_pytorch_predictor)
# Keep the module namespace free of the loop variable.
del _predictor_key
|
||
|
||
|
||
def _load_sales_frame(training_mode, product_id, store_id):
    """Load the aggregated sales history for the training mode plus a display name.

    Returns:
        ``(product_df, product_name)``.
    """
    if training_mode == 'store' and store_id:
        from utils.multi_store_data_utils import load_multi_store_data
        store_df_for_name = load_multi_store_data(store_id=store_id)
        product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
        product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
    elif training_mode == 'global':
        product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        product_name = "全局销售数据"
    else:
        product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
    return product_df, product_name


def _build_pytorch_model(loaded_model_type, config):
    """Instantiate the PyTorch architecture for ``loaded_model_type`` from its saved config.

    The returned model is moved to ``DEVICE``; weights are NOT loaded here.

    Raises:
        ValueError: If the model type is not one of the supported architectures.
    """
    if loaded_model_type == 'transformer':
        model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32)
    elif loaded_model_type == 'mlstm':
        model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim'])
    elif loaded_model_type == 'kan':
        model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']], output_sequence_length=config['output_dim'])
    elif loaded_model_type == 'optimized_kan':
        model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']], output_sequence_length=config['output_dim'])
    elif loaded_model_type == 'tcn':
        model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout'])
    elif loaded_model_type == 'cnn_bilstm_attention':
        model = CnnBiLstmAttention(
            input_dim=config['input_dim'],
            output_dim=config['output_dim'],
            sequence_length=config['sequence_length'],
        )
    else:
        raise ValueError(f"不支持的模型类型: {loaded_model_type}")
    return model.to(DEVICE)


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """Load a trained model checkpoint and forecast sales (v4, plugin-style predictors).

    Args:
        model_path: Path to the checkpoint written by training.
        product_id: Product to forecast (also used in the result and analysis).
        model_type: Fallback model type when the checkpoint config has no
            ``'model_type'``.
        store_id: Store to aggregate when ``training_mode == 'store'``.
        future_days: Number of days to forecast.
        start_date: Optional forecast start date forwarded to the predictor.
        analyze_result: When true, also run ``analyze_prediction_result``.
        version: Currently unused by this function.
        training_mode: ``'product'`` (default), ``'store'`` or ``'global'``.
        history_lookback_days: History rows returned for charting.

    Returns:
        A dict with keys ``product_id``, ``product_name``, ``model_type``,
        ``predictions`` / ``prediction_data`` (the same list of records),
        ``history_data`` and ``analysis``. Returns ``None`` if any exception
        occurs; the error is printed with a traceback.
    """
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")

        product_df, product_name = _load_sales_frame(training_mode, product_id, store_id)
        if product_df.empty:
            raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")

        # Best-effort allowlisting of MinMaxScaler for torch versions that
        # support it. It only matters for weights_only=True loads, so any
        # failure here is deliberately ignored.
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception:
            pass

        # SECURITY NOTE: weights_only=False unpickles arbitrary objects (the
        # scalers, and the XGBoost booster). Only load checkpoint files from a
        # trusted source.
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint.get('config', {})
        # Prefer the model type recorded inside the checkpoint over the caller's.
        loaded_model_type = config.get('model_type', model_type)

        if loaded_model_type == 'xgboost':
            # XGBoost checkpoints store the booster object itself under this key.
            model = checkpoint['model_state_dict']
        else:
            model = _build_pytorch_model(loaded_model_type, config)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")

        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days,
        )

        analysis = None
        if analyze_result:
            try:
                # Use the same default feature list as the default predictor.
                # A bare config.get('features') is None when the key is absent,
                # which made df[None] raise and analysis silently fail.
                analysis_features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
                analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[analysis_features].values)
            except Exception as e:
                # Analysis is optional: report the failure but still return predictions.
                print(f"分析预测结果失败: {str(e)}")

        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis,
        }

    except Exception as e:
        # Top-level boundary: callers receive None instead of an exception.
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None