ShopTRAINING/server/predictors/model_predictor.py

260 lines
14 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
from typing import Optional
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from models.cnn_bilstm_attention import CnnBiLstmAttention
import xgboost as xgb
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.new_data_loader import load_new_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
from models.model_registry import get_predictor, register_predictor
def default_pytorch_predictor(model, checkpoint, product_df, future_days, start_date, history_lookback_days):
    """
    Default predictor shared by the PyTorch models, XGBoost and CnnBiLstmAttention.

    XGBoost boosters and ``CnnBiLstmAttention`` models are called once and
    their whole output is used as the forecast horizon. Every other model is
    run autoregressively, one step per requested day, feeding each prediction
    back into the input window.

    Args:
        model: An ``xgb.Booster``, a ``CnnBiLstmAttention`` instance, or any
            other callable PyTorch model.
        checkpoint: Mapping holding ``'config'``, ``'scaler_X'`` and
            ``'scaler_y'``. ``config['sequence_length']`` is required;
            ``config['features']`` is optional and falls back to a default
            seven-column list.
        product_df: History frame with a ``'date'`` column plus the feature
            columns. Presumably sorted by date ascending, since ``tail`` is
            used to pick the most recent rows -- TODO confirm with callers.
        future_days: Number of autoregressive steps. Ignored by the one-shot
            XGBoost / CnnBiLstmAttention paths, whose output length is fixed
            by the model.
        start_date: Optional first forecast date (anything accepted by
            ``pd.to_datetime``). When falsy, forecasting starts the day after
            the last row of ``product_df``.
        history_lookback_days: Number of rows before the start date returned
            for charting.

    Returns:
        Tuple ``(predictions_df, history_for_chart_df, prediction_input_df)``
        where ``predictions_df`` has the columns ``date`` and
        ``predicted_sales``.

    Raises:
        ValueError: If fewer than ``sequence_length`` history rows are
            available before the start date.
    """
    config = checkpoint['config']
    scaler_X = checkpoint['scaler_X']
    scaler_y = checkpoint['scaler_y']
    features = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
    sequence_length = config['sequence_length']

    if start_date:
        start_date_dt = pd.to_datetime(start_date)
        prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
    else:
        prediction_input_df = product_df.tail(sequence_length)
        start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

    if len(prediction_input_df) < sequence_length:
        raise ValueError(f"预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")

    history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

    is_xgboost = isinstance(model, xgb.Booster)
    if is_xgboost or isinstance(model, CnnBiLstmAttention):
        # --- One-shot path: the model emits the whole horizon in a single call ---
        X_current_scaled = scaler_X.transform(prediction_input_df[features].values)

        if is_xgboost:
            # The booster takes the flattened window as one row.
            d_input = xgb.DMatrix(X_current_scaled.reshape(1, -1))
            # iteration_range is half-open [begin, end): the end must be
            # best_iteration + 1 so that the best tree itself is included.
            # best_iteration may be undefined when training ran without
            # early stopping; in that case use every tree.
            best_iteration = getattr(model, 'best_iteration', None)
            if best_iteration is None:
                y_pred_scaled = model.predict(d_input)
            else:
                y_pred_scaled = model.predict(d_input, iteration_range=(0, best_iteration + 1))
            y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled.reshape(1, -1)).flatten()
        else:
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            y_pred_unscaled = scaler_y.inverse_transform(y_pred_scaled).flatten()

        # Sales cannot be negative.
        y_pred_unscaled = np.maximum(0, y_pred_unscaled)

        # The number of forecast dates follows the model's actual output
        # length, not the caller's future_days, because that length is fixed.
        future_dates = pd.date_range(start=start_date_dt, periods=len(y_pred_unscaled))
        predictions_df = pd.DataFrame({
            'date': future_dates,
            'predicted_sales': y_pred_unscaled
        })
    else:
        # --- Default PyTorch path: autoregressive, one step per future day ---
        # NOTE(review): forecast dates continue from the last input row, not
        # from start_date_dt; the two differ if there is a gap between the
        # history and an explicit start_date -- confirm this is intended.
        all_predictions = []
        current_sequence_df = prediction_input_df.copy()
        for _ in range(future_days):
            X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()

            # Only the first predicted step is kept; the window is then advanced by one day.
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
            next_step_pred_unscaled = float(max(0, scaler_y.inverse_transform(next_step_pred_scaled)[0][0]))
            next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
            all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

            # Synthesize the next input row: calendar fields are derived from
            # the date, holiday/promotion flags are set to 0 and the last
            # known temperature is carried forward.
            new_row = {
                'date': next_date,
                'sales': next_step_pred_unscaled,
                'weekday': next_date.weekday(),
                'month': next_date.month,
                'is_holiday': 0,
                'is_weekend': 1 if next_date.weekday() >= 5 else 0,
                'is_promotion': 0,
                'temperature': current_sequence_df['temperature'].iloc[-1],
            }
            new_row_df = pd.DataFrame([new_row])
            current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

        # Explicit columns keep the frame well-formed even when future_days is 0.
        predictions_df = pd.DataFrame(all_predictions, columns=['date', 'predicted_sales'])

    return predictions_df, history_for_chart_df, prediction_input_df
# Register the shared predictor under every model key it serves, in order:
# the generic PyTorch fallback, XGBoost, and the CnnBiLstmAttention model.
for _model_key in ('default', 'xgboost', 'cnn_bilstm_attention'):
    register_predictor(_model_key, default_pytorch_predictor)
def _aggregate_sales_by_date(df):
    """Collapse rows to one per date: sales summed, temperature averaged, calendar flags taken from the first row."""
    return df.groupby('date').agg({
        'sales': 'sum',
        'weekday': 'first', 'month': 'first', 'is_holiday': 'first',
        'is_weekend': 'first', 'is_promotion': 'first', 'temperature': 'mean'
    }).reset_index()


def _build_torch_model(loaded_model_type, config):
    """Instantiate the (still untrained) PyTorch model for ``loaded_model_type`` on DEVICE from the checkpoint config."""
    if loaded_model_type == 'transformer':
        return TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
    if loaded_model_type == 'mlstm':
        return MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
    if loaded_model_type == 'kan':
        return KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
    if loaded_model_type == 'optimized_kan':
        return OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
    if loaded_model_type == 'tcn':
        return TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
    if loaded_model_type == 'cnn_bilstm_attention':
        return CnnBiLstmAttention(
            input_dim=config['input_dim'],
            output_dim=config['output_dim'],
            sequence_length=config['sequence_length']
        ).to(DEVICE)
    raise ValueError(f"不支持的模型类型: {loaded_model_type}")


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """
    Load a trained model checkpoint and produce a forecast (v4, plugin-style predictors).

    Args:
        model_path: Path of the checkpoint file; must exist.
        product_id: Product to forecast. In ``'store'`` and ``'global'`` modes
            a falsy value, ``'unknown'`` or ``'all_products'`` selects the
            aggregate over all products instead.
        model_type: Fallback model type, used only when the checkpoint config
            does not carry its own ``'model_type'``.
        store_id: Store to forecast for; only used when ``training_mode`` is
            ``'store'``.
        future_days: Forecast horizon passed on to the predictor.
        start_date: Optional first forecast date passed on to the predictor.
        analyze_result: When True, additionally run ``analyze_prediction_result``.
        version: Not referenced in this function.
        training_mode: ``'store'``, ``'global'``, or anything else for the
            default per-product mode. ``'store'`` without a ``store_id`` also
            falls back to per-product mode.
        history_lookback_days: Amount of history returned for charting.

    Returns:
        A dict with the keys ``product_id``, ``product_name``, ``model_type``,
        ``predictions``, ``prediction_data`` (same records as ``predictions``),
        ``history_data`` and ``analysis``; or ``None`` if any step raised
        (the exception is printed together with its traceback).
    """
    try:
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"模型文件 {model_path} 不存在")

        # --- Data loading ---
        # Always go through the unified loader so that the data source and
        # preprocessing match what was used at training time.
        print("正在使用新的统一数据加载器进行预测...")
        full_df = load_new_data()

        # True when a concrete product (not an "all products" sentinel) was requested.
        has_specific_product = bool(product_id) and product_id not in ('unknown', 'all_products')

        if training_mode == 'store' and store_id:
            store_df = full_df[full_df['store_id'] == store_id].copy()
            if has_specific_product:
                # Single product within the store.
                product_df = store_df[store_df['product_id'] == product_id].copy()
                product_name = f"店铺 {store_id} - 产品 {product_id}"
            else:
                # Aggregate sales of the whole store.
                product_df = _aggregate_sales_by_date(store_df)
                product_name = f"店铺 {store_id} (所有药品聚合)"
        elif training_mode == 'global':
            if has_specific_product:
                # Unusual in global mode but supported: aggregate this product across stores.
                product_df = full_df[full_df['product_id'] == product_id].copy()
                product_name = f"全局聚合 - 产品 {product_id}"
            else:
                # Global forecast over every product.
                product_df = full_df.copy()
                product_name = "全局聚合 (所有药品)"
            product_df = _aggregate_sales_by_date(product_df)
        else:
            # Default 'product' mode.
            product_df = full_df[full_df['product_id'] == product_id].copy()
            # Newer data may not carry a product_name column.
            if 'product_name' in product_df.columns and not product_df['product_name'].empty:
                product_name = product_df['product_name'].iloc[0]
            else:
                product_name = f"Product {product_id}"

        if product_df.empty:
            raise ValueError(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")

        # --- Model loading and instantiation ---
        # Best-effort allow-listing of MinMaxScaler; older torch releases do
        # not provide add_safe_globals, and a failure here must not abort the
        # prediction.
        register_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
        if register_safe_globals is not None:
            try:
                register_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
            except Exception as e:
                print(f"注册 MinMaxScaler 安全全局对象失败 (已忽略): {str(e)}")

        # SECURITY NOTE: weights_only=False performs full pickle
        # deserialization (the checkpoint stores scaler objects and, for
        # XGBoost, the booster itself). Only load checkpoints from trusted
        # locations.
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        config = checkpoint.get('config', {})
        # The type stored inside the checkpoint wins over the caller's hint.
        loaded_model_type = config.get('model_type', model_type)

        if loaded_model_type == 'xgboost':
            # For XGBoost the model object itself is stored under 'model_state_dict'.
            model = checkpoint['model_state_dict']
        else:
            # PyTorch models are rebuilt from config, then the weights are loaded.
            model = _build_torch_model(loaded_model_type, config)
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()

        # --- Dispatch to the registered predictor ---
        predictor_function = get_predictor(loaded_model_type)
        if not predictor_function:
            raise ValueError(f"找不到模型类型 '{loaded_model_type}' 的预测器实现")

        predictions_df, history_for_chart_df, prediction_input_df = predictor_function(
            model=model,
            checkpoint=checkpoint,
            product_df=product_df,
            future_days=future_days,
            start_date=start_date,
            history_lookback_days=history_lookback_days
        )

        # --- Optional analysis ---
        analysis = None
        if analyze_result:
            try:
                # Same fallback as the default predictor, so a checkpoint
                # without an explicit feature list no longer indexes with None.
                feature_columns = config.get('features', ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'])
                analysis = analyze_prediction_result(product_id, loaded_model_type, predictions_df['predicted_sales'].values, prediction_input_df[feature_columns].values)
            except Exception as e:
                # Analysis is optional; report the failure and still return the forecast.
                print(f"分析预测结果失败: {str(e)}")

        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': loaded_model_type,
            'predictions': prediction_data_json,
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
        # Top-level boundary: callers receive None instead of an exception.
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None