ShopTRAINING/server/predictors/model_predictor.py
2025-07-22 15:41:05 +08:00

191 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
from typing import Optional
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
"""
默认的PyTorch模型预测逻辑支持自动回归。
"""
config = checkpoint['config']
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
sequence_length = config['sequence_length']
if start_date:
start_date_dt = pd.to_datetime(start_date)
prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
else:
prediction_input_df = product_df.tail(sequence_length)
start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
if len(prediction_input_df) < sequence_length:
raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
all_predictions = []
current_sequence_df = prediction_input_df.copy()
for _ in range(future_days):
X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
# **核心改进**: 智能判断模型类型并调用相应的预测方法
if isinstance(model, xgb.Booster):
# XGBoost 模型预测路径
X_input_reshaped = X_current_scaled.reshape(1, -1)
d_input = xgb.DMatrix(X_input_reshaped)
# **关键修复**: 使用 best_iteration 进行预测,以匹配早停策略
y_pred_scaled = model.predict(d_input, iteration_range=(0, model.best_iteration))
next_step_pred_scaled = y_pred_scaled.reshape(1, -1)
else:
# 默认 PyTorch 模型预测路径
X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})
new_row = {'date': next_date, 'sales': next_step_pred_unscaled, 'weekday': next_date.weekday(), 'month': next_date.month, 'is_holiday': 0, 'is_weekend': 1 if next_date.weekday() >= 5 else 0, 'is_promotion': 0, 'temperature': current_sequence_df['temperature'].iloc[-1]}
new_row_df = pd.DataFrame([new_row])
current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)
predictions_df = pd.DataFrame(all_predictions)
return predictions_df, history_for_chart_df, prediction_input_df
# Every model type below shares the generic autoregressive loop in
# default_pytorch_predictor: the generic 'default' entry, XGBoost (served by
# the Booster branch inside that predictor) and the CNN-BiLSTM-Attention model.
for _predictor_key in ('default', 'xgboost', 'cnn_bilstm_attention'):
    register_predictor(_predictor_key, default_pytorch_predictor)
del _predictor_key
def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
"""
加载已训练的模型并进行预测 (v4版 - 插件式架构)
"""
try:
if not os.path.exists(model_path):
raise FileNotFoundError(f"模型文件 {model_path} 不存在")
# --- 数据加载部分保持不变 ---
from utils.multi_store_data_utils import aggregate_multi_store_data
if training_mode == 'store' and store_id:
from utils.multi_store_data_utils import load_multi_store_data
store_df_for_name = load_multi_store_data(store_id=store_id)
product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"
product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
elif training_mode == 'global':
product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
product_name = "全局销售数据"
else:
product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
if product_df.empty:
raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
# --- 模型加载与实例化 (重构) ---
try:
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
except Exception: pass
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
config = checkpoint.get('config', {})
loaded_model_type = config.get('model_type', model_type) # 优先使用模型内保存的类型
# 根据模型类型决定如何获取模型实例
if loaded_model_type == 'xgboost':
# 对于XGBoost, 模型对象直接保存在'model_state_dict'键中
model = checkpoint['model_state_dict']
else:
# 对于PyTorch模型, 需要重新构建实例并加载state_dict
if loaded_model_type == 'transformer':
model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
elif loaded_model_type == 'mlstm':
model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'kan':
model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'optimized_kan':
model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
elif loaded_model_type == 'tcn':
model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
elif loaded_model_type == 'cnn_bilstm_attention':
model = CnnBiLstmAttention(
input_dim=config['input_dim'],
output_dim=config['output_dim'],
sequence_length=config['sequence_length']
).to(DEVICE)
else:
raise ValueError(f"不支持的模型类型: {loaded_model_type}")
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
# --- 动态调用预测器 ---
predictor_function = get_predictor(loaded_model_type)
if not predictor_function:
raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")
predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
model=model,
checkpoint=checkpoint,
product_df=product_df,
future_days=future_days,
start_date=start_date,
history_lookback_days=history_lookback_days
)
# --- 分析与返回部分保持不变 ---
analysis = None
if analyze_result:
try:
analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[config.get('features')].values)
except Exception as e:
print(f"分析预测结果失败: {str(e)}")
history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []
return {
'product_id': product_id,
'product_name': product_name,
'model_type': loaded_model_type,
'predictions': prediction_data_json,
'prediction_data': prediction_data_json,
'history_data': history_data_json,
'analysis': analysis
}
except Exception as e:
print(f"预测过程中出现未捕获的异常: {str(e)}")
import traceback
traceback.print_exc()
return None