"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # needed so MinMaxScaler can be deserialized from the checkpoint
from typing import Optional
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """
    Load a trained model and run a forecast (v3 - supports autoregressive prediction).

    Parameters:
        ... (same as before, with history_lookback_days added)
        history_lookback_days: number of days of history to include for chart display

    Returns:
        Prediction results and analysis
    """
    try:
        print(f"v3 prediction function started. Model path: {model_path}, days to forecast: {future_days}, history lookback: {history_lookback_days}")
        if not os.path.exists(model_path):
            print(f"Model file {model_path} does not exist")
            return None
        # Load the sales data
        if training_mode == 'store' and store_id:
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['store_name'].iloc[0] if not product_df.empty else f"Store {store_id}"
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "Global sales data"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id
        if product_df.empty:
            print(f"No sales data for product {product_id} or store {store_id}")
            return None
        # Load the model checkpoint and configuration
        try:
            # Allowlist MinMaxScaler so the pickled scalers can be deserialized by torch.load
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception:
            pass
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("Model file is incomplete: missing config or scaler")
            return None
        config = checkpoint['config']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']
        # Build the model instance (construction is unchanged from the previous version)
        if model_type == 'transformer':
            model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
        elif model_type == 'mlstm':
            model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'kan':
            model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'optimized_kan':
            model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'tcn':
            model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
        else:
            print(f"Unsupported model type: {model_type}")
            return None
        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()
        # --- Core logic change: autoregressive forecasting ---
        features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        sequence_length = config['sequence_length']

        # Determine the starting point of the forecast
        if start_date:
            start_date_dt = pd.to_datetime(start_date)
            # Use the `sequence_length` days before the start date as the initial input window
            prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
        else:
            # If no start date is given, forecast from the day after the last date in the data
            prediction_input_df = product_df.tail(sequence_length)
            start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)
        if len(prediction_input_df) < sequence_length:
            print(f"Error: not enough history for prediction. Need {sequence_length} days but only {len(prediction_input_df)} are available.")
            return None
        # Historical data used only for chart display
        history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)
        # Autoregressive prediction loop
        all_predictions = []
        current_sequence_df = prediction_input_df.copy()
        print(f"Starting autoregressive forecast for {future_days} days...")
        for i in range(future_days):
            # Build the input tensor for the current window
            X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)

            # Run one forward pass (the model may output several steps, but only the first is used)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()

            # Extract the prediction for the next time step
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
            next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
            next_step_pred_unscaled = float(max(0, next_step_pred_unscaled))  # clamp negative sales to zero and cast to a plain float

            # Date of the newly predicted point
            next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
            all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

            # Build the new row used to extend the input window
            new_row = {
                'date': next_date,
                'sales': next_step_pred_unscaled,
                'weekday': next_date.weekday(),
                'month': next_date.month,
                'is_holiday': 0,
                'is_weekend': 1 if next_date.weekday() >= 5 else 0,
                'is_promotion': 0,
                'temperature': current_sequence_df['temperature'].iloc[-1]  # carry the last known temperature forward
            }

            # Slide the window: drop the oldest row and append the newly predicted one
            new_row_df = pd.DataFrame([new_row])
            current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

        predictions_df = pd.DataFrame(all_predictions)
        print(f"Autoregressive forecast finished; generated {len(predictions_df)} prediction records.")
        # Analysis and visualization
        analysis = None
        if analyze_result:
            try:
                y_pred_for_analysis = predictions_df['predicted_sales'].values
                # Analyze against the features of the initial input window
                initial_features_for_analysis = prediction_input_df[features].values
                analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
            except Exception as e:
                print(f"Failed to analyze the prediction result: {str(e)}")
        # Convert the DataFrames to the JSON array format expected by the frontend before returning
        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': prediction_data_json,  # legacy field kept for compatibility; uses the converted JSON
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
        print(f"Unhandled exception during prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
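

# Minimal usage sketch, assuming a checkpoint trained by this project already exists on disk.
# The path, product ID, and model type below are placeholders (not guaranteed to exist in the
# repository); adjust them to a real checkpoint before running. It only illustrates how the
# return value of load_model_and_predict is typically consumed.
if __name__ == '__main__':
    result = load_model_and_predict(
        model_path='saved_models/transformer_P001.pth',  # hypothetical checkpoint path
        product_id='P001',                               # hypothetical product ID
        model_type='transformer',
        future_days=7,
        start_date=None,            # None -> forecast from the day after the last date in the data
        analyze_result=False,
        training_mode='product',
        history_lookback_days=30,
    )
    if result is not None:
        print(f"{result['product_name']}: {len(result['prediction_data'])} days forecast")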