ShopTRAINING/server/predictors/model_predictor.py

246 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
from typing import Optional
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor
from utils.feature_selection import get_feature_list_for_model
def _resolve_feature_names(config, scaler_X, product_df):
    """Return the feature column names to feed the scaler and the model.

    Lookup order: the list saved in the model's own config, then the names
    recorded on the fitted scaler, then dynamic selection from the data frame's
    columns (the original code marks this last resort as unreliable for some
    models).
    """
    if config.get('features'):
        return config['features']
    if hasattr(scaler_X, 'feature_names_in_'):
        # Convert the array of names to a plain list for column selection.
        return list(scaler_X.feature_names_in_)
    print("⚠️ WARNING: Could not find feature list in config or scaler. Falling back to dynamic feature selection.")
    model_type = config.get('model_type', 'default')
    return get_feature_list_for_model(model_type, product_df.columns.tolist())


def _predict_next_scaled(model, X_current_scaled, sequence_length):
    """Run one forward step and return the scaled prediction with shape (1, k)."""
    if isinstance(model, xgb.Booster):
        # XGBoost consumes the whole window flattened into a single row.
        d_input = xgb.DMatrix(X_current_scaled.reshape(1, -1))
        # `best_iteration` is a zero-based index while `iteration_range` is
        # half-open, so the end bound must be `best_iteration + 1` to include
        # the best tree. Boosters trained without early stopping may not
        # expose the attribute at all; in that case use every tree.
        best_iteration = getattr(model, 'best_iteration', None)
        if best_iteration is None:
            y_pred_scaled = model.predict(d_input)
        else:
            y_pred_scaled = model.predict(d_input, iteration_range=(0, int(best_iteration) + 1))
        return y_pred_scaled.reshape(1, -1)

    # Default PyTorch path: a batch of one sequence.
    X_input = torch.tensor(
        X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32
    ).to(DEVICE)
    with torch.no_grad():
        y_pred_scaled = model(X_input).cpu().numpy()
    # Only the first predicted step is used; later steps come from autoregression.
    return y_pred_scaled[0, 0].reshape(1, -1)


def _build_next_row(last_row, next_date, predicted_sales):
    """Build the synthetic row appended to the window for the next step.

    Starts from a copy of the last known row so every other column is carried
    forward unchanged, then overwrites the date, the sales value and the
    calendar fields.
    """
    new_row = last_row.copy()
    new_row['date'] = next_date
    new_row['sales'] = predicted_sales
    new_row['weekday'] = next_date.weekday()
    new_row['month'] = next_date.month
    new_row['is_holiday'] = 0  # Future days are assumed not to be holidays.
    new_row['is_weekend'] = 1 if next_date.weekday() >= 5 else 0
    # NOTE: rolling features are not recomputed; the last known values are
    # reused. This is a simplification that trades accuracy for simplicity.
    return new_row


def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """Autoregressive day-by-day forecast shared by PyTorch and XGBoost models.

    Args:
        model: A callable PyTorch model or an ``xgb.Booster``.
        checkpoint: Mapping providing ``'config'``, ``'scaler_X'`` and
            ``'scaler_y'``. The config must contain ``'sequence_length'``.
        product_df: Sales history with a ``'date'`` column and the feature
            columns. Presumably sorted by date ascending, since the last rows
            are taken as the most recent -- TODO confirm with the loader.
        future_days: Number of steps to forecast.
        start_date: Optional date; when given, only rows strictly before it
            are used as model input.
        history_lookback_days: Number of trailing history rows returned for
            charting.

    Returns:
        Tuple ``(predictions_df, history_for_chart_df, prediction_input_df)``.
        ``predictions_df`` has ``'date'`` and ``'predicted_sales'`` columns
        (and is empty when ``future_days`` is 0).

    Raises:
        ValueError: When fewer than ``sequence_length`` history rows are
            available.
    """
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = _resolve_feature_names(config, scaler_X, product_df)
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        start_date_dt = None
        prediction_input_df = product_df.tail(sequence_length)

    # Validate before touching `.iloc[-1]` so an empty frame reports the
    # intended error instead of an IndexError.
    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    if start_date_dt is None:
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days).copy()

    # The scaler was fitted on numeric columns only, so keep just those.
    # Computed once from the real history: the set of columns must stay fixed
    # across steps to match what the scaler expects.
    numeric_features = prediction_input_df[features].select_dtypes(include=np.number).columns.tolist()

    all_predictions = []
    current_sequence_df = prediction_input_df.copy()
    for _ in range(future_days):
        # Coercion + fillna guard against stray non-numeric values that may
        # appear once synthetic rows are appended.
        features_df = current_sequence_df[numeric_features].apply(pd.to_numeric, errors='coerce').fillna(0)
        X_current_scaled = scaler_X.transform(features_df.values)

        next_step_pred_scaled = _predict_next_scaled(model, X_current_scaled, sequence_length)
        # Negative sales are clamped to zero.
        next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))

        next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
        all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

        # Slide the window: drop the oldest row, append the predicted one.
        new_row_df = pd.DataFrame(
            [_build_next_row(current_sequence_df.iloc[-1], next_date, next_step_pred_unscaled)]
        )
        current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

    predictions_df = pd.DataFrame(all_predictions)
    return predictions_df, history_for_chart_df, prediction_input_df
# Register the generic autoregressive predictor for every model type it
# serves: the 'default' fallback, XGBoost (handled by a dedicated branch
# inside the predictor), and the PyTorch architectures listed below.
# Registration order matches the original sequence of calls.
for _registered_model_type in (
    'default',
    'xgboost',
    'cnn_bilstm_attention',
    'transformer',
    'mlstm',
    'kan',
    'optimized_kan',
    'tcn',
):
    register_predictor(_registered_model_type, default_pytorch_predictor)
# Do not leave the loop variable behind in the module namespace.
del _registered_model_type
def _load_sales_frame(product_id, store_id, training_mode):
    """Load the sales data for the given training mode and pick a display name.

    Returns a ``(product_df, product_name)`` tuple. The frame may be empty;
    the caller is responsible for rejecting that case.
    """
    from utils.multi_store_data_utils import load_multi_store_data

    if training_mode == 'store' and store_id:
        product_df = load_multi_store_data(store_id=store_id, file_path=DEFAULT_DATA_PATH)
        if not product_df.empty and 'store_name' in product_df.columns:
            product_name = product_df['store_name'].iloc[0]
        else:
            product_name = f"店铺 {store_id}"
    elif training_mode == 'global':
        product_df = load_multi_store_data(file_path=DEFAULT_DATA_PATH)
        product_name = "全局销售数据"
    else:
        product_df = load_multi_store_data(product_id=product_id, file_path=DEFAULT_DATA_PATH)
        if not product_df.empty and 'product_name' in product_df.columns:
            product_name = product_df['product_name'].iloc[0]
        else:
            product_name = product_id
    return product_df, product_name


def _instantiate_torch_model(loaded_model_type, config):
    """Rebuild an untrained PyTorch model from the hyper-parameters in ``config``.

    Raises:
        ValueError: When ``loaded_model_type`` is not one of the supported
            architectures.
    """
    if loaded_model_type == 'transformer':
        return TimeSeriesTransformer(
            num_features=config['input_dim'],
            d_model=config['hidden_size'],
            nhead=config['num_heads'],
            num_encoder_layers=config['num_layers'],
            dim_feedforward=config['hidden_size'] * 2,
            dropout=config['dropout'],
            output_sequence_length=config['output_dim'],
            seq_length=config['sequence_length'],
            batch_size=32,
        ).to(DEVICE)
    if loaded_model_type == 'mlstm':
        return MatrixLSTM(
            num_features=config['input_dim'],
            hidden_size=config['hidden_size'],
            mlstm_layers=config['mlstm_layers'],
            embed_dim=config.get('embed_dim', 32),
            dense_dim=config.get('dense_dim', 32),
            num_heads=config.get('num_heads', 4),
            dropout_rate=config['dropout_rate'],
            num_blocks=config.get('num_blocks', 3),
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if loaded_model_type == 'kan':
        return KANForecaster(
            input_features=config['input_dim'],
            hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if loaded_model_type == 'optimized_kan':
        return OptimizedKANForecaster(
            input_features=config['input_dim'],
            hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if loaded_model_type == 'tcn':
        return TCNForecaster(
            num_features=config['input_dim'],
            output_sequence_length=config['output_dim'],
            num_channels=[config['hidden_size']] * config['num_layers'],
            kernel_size=config['kernel_size'],
            dropout=config['dropout'],
        ).to(DEVICE)
    if loaded_model_type == 'cnn_bilstm_attention':
        return CnnBiLstmAttention(
            input_dim=config['input_dim'],
            output_dim=config['output_dim'],
            sequence_length=config['sequence_length'],
        ).to(DEVICE)
    raise ValueError(f"不支持的模型类型: {loaded_model_type}")


def _analysis_feature_columns(config, checkpoint, prediction_input_df):
    """Pick the columns handed to the trend analysis.

    Uses the feature list saved in the config when present. Otherwise falls
    back to the names recorded on the fitted input scaler, then to dynamic
    selection, so checkpoints without a saved list can still be analysed
    instead of failing on ``df[None]``.
    """
    if config.get('features'):
        return config['features']
    scaler_X = checkpoint.get('scaler_X')
    if hasattr(scaler_X, 'feature_names_in_'):
        return list(scaler_X.feature_names_in_)
    return get_feature_list_for_model(
        config.get('model_type', 'default'), prediction_input_df.columns.tolist()
    )


def load_model_and_predict(
    model_path: str,
    product_id: str,
    model_type: str,
    store_id: Optional[str] = None,
    future_days: int = 7,
    start_date: Optional[str] = None,
    analyze_result: bool = False,
    version: Optional[str] = None,
    training_mode: str = 'product',
    history_lookback_days: int = 30,
):
    """Load a trained model from disk and forecast future sales.

    Args:
        model_path: Path to the saved checkpoint file.
        product_id: Product identifier; used to load data in the default
            (product) mode and echoed back in the result.
        model_type: Fallback model type, used only when the checkpoint's
            config does not record one.
        store_id: Store identifier, used when ``training_mode == 'store'``.
        future_days: Number of days to forecast.
        start_date: Optional forecast start date passed to the predictor.
        analyze_result: When true, run trend analysis on the predictions.
        version: Accepted for interface compatibility; not used here.
        training_mode: ``'store'``, ``'global'``, or anything else for the
            per-product mode.
        history_lookback_days: Number of history rows returned for charting.

    Returns:
        A dict with ``product_id``, ``product_name``, ``model_type``,
        ``predictions``, ``prediction_data``, ``history_data`` and
        ``analysis`` keys, or ``None`` when any step fails (the error and
        traceback are printed).
    """
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")

        # --- Data loading ---
        product_df, product_name = _load_sales_frame(product_id, store_id, training_mode)
        if product_df.empty:
            raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")

        # --- Checkpoint loading ---
        # Older torch releases have no `add_safe_globals`; check for it
        # explicitly rather than swallowing every exception.
        add_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
        if add_safe_globals is not None:
            add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        # SECURITY NOTE: weights_only=False performs full pickle
        # deserialization, which can execute arbitrary code. Only load
        # checkpoint files from trusted sources.
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint.get('config', {})
        # The type recorded inside the checkpoint wins over the caller's hint.
        loaded_model_type = config.get('model_type', model_type)

        # --- Model instantiation ---
        if loaded_model_type == 'xgboost':
            # For XGBoost the model object itself is stored under this key.
            model = checkpoint['model_state_dict']
        else:
            model = _instantiate_torch_model(loaded_model_type, config)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        # --- Dispatch to the registered predictor ---
        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")
        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days,
        )

        # --- Optional analysis (best effort: a failure must not lose the forecast) ---
        analysis = None
        if analyze_result:
            try:
                analysis_columns = _analysis_feature_columns(config, checkpoint, prediction_input_df)
                analysis = analyze_prediction_result(
                    product_id,
                    loaded_model_type,
                    predictions_df['predicted_sales'].values,
                    prediction_input_df[analysis_columns].values,
                )
            except Exception as e:
                print(f"分析预测结果失败: {str(e)}")

        # --- Serialise: every date column becomes a 'YYYY-MM-DD' string ---
        if not history_for_chart_df.empty and 'date' in history_for_chart_df.columns:
            # `to_datetime` first, because the column may hold mixed date types.
            history_for_chart_df['date'] = pd.to_datetime(history_for_chart_df['date']).dt.strftime('%Y-%m-%d')
        if not predictions_df.empty and 'date' in predictions_df.columns:
            predictions_df['date'] = pd.to_datetime(predictions_df['date']).dt.strftime('%Y-%m-%d')

        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            # Same payload under a second key, as in the original response.
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis,
        }
    except Exception as e:
        # Top-level boundary: callers rely on a None result rather than an exception.
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None