189 lines
10 KiB
Python
189 lines
10 KiB
Python
"""
|
||
药店销售预测系统 - 模型预测函数
|
||
"""
|
||
|
||
import os
|
||
import torch
|
||
import pandas as pd
|
||
import numpy as np
|
||
from datetime import datetime, timedelta
|
||
import matplotlib.pyplot as plt
|
||
from sklearn.preprocessing import MinMaxScaler
|
||
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
|
||
from typing import Optional
|
||
|
||
from models.transformer_model import TimeSeriesTransformer
|
||
from models.slstm_model import sLSTM as ScalarLSTM
|
||
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
|
||
from models.kan_model import KANForecaster
|
||
from models.tcn_model import TCNForecaster
|
||
from models.optimized_kan_forecaster import OptimizedKANForecaster
|
||
|
||
from analysis.trend_analysis import analyze_prediction_result
|
||
from utils.visualization import plot_prediction_results
|
||
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
|
||
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH
|
||
|
||
def _build_forecast_model(model_type: str, config: dict):
    """Instantiate the network for ``model_type`` from the checkpoint ``config``.

    Returns the model moved to ``DEVICE``, or ``None`` when ``model_type`` is
    not one of 'transformer', 'mlstm', 'kan', 'optimized_kan' or 'tcn'.
    Missing ``config`` keys raise ``KeyError`` (handled by the caller).
    """
    if model_type == 'transformer':
        return TimeSeriesTransformer(
            num_features=config['input_dim'],
            d_model=config['hidden_size'],
            nhead=config['num_heads'],
            num_encoder_layers=config['num_layers'],
            dim_feedforward=config['hidden_size'] * 2,
            dropout=config['dropout'],
            output_sequence_length=config['output_dim'],
            seq_length=config['sequence_length'],
            batch_size=32,
        ).to(DEVICE)
    if model_type == 'mlstm':
        return MatrixLSTM(
            num_features=config['input_dim'],
            hidden_size=config['hidden_size'],
            mlstm_layers=config['mlstm_layers'],
            embed_dim=config.get('embed_dim', 32),
            dense_dim=config.get('dense_dim', 32),
            num_heads=config.get('num_heads', 4),
            dropout_rate=config['dropout_rate'],
            num_blocks=config.get('num_blocks', 3),
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if model_type == 'kan':
        return KANForecaster(
            input_features=config['input_dim'],
            hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if model_type == 'optimized_kan':
        return OptimizedKANForecaster(
            input_features=config['input_dim'],
            hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
            output_sequence_length=config['output_dim'],
        ).to(DEVICE)
    if model_type == 'tcn':
        return TCNForecaster(
            num_features=config['input_dim'],
            output_sequence_length=config['output_dim'],
            num_channels=[config['hidden_size']] * config['num_layers'],
            kernel_size=config['kernel_size'],
            dropout=config['dropout'],
        ).to(DEVICE)
    return None


def _autoregressive_forecast(model, scaler_X, scaler_y, seed_df, features, sequence_length, future_days):
    """Roll the model forward one day at a time for ``future_days`` steps.

    Each step feeds the current window to the model, keeps only the first
    predicted point, and appends it to the window as a synthetic row while the
    oldest row is dropped. Returns a list of
    ``{'date': ..., 'predicted_sales': float}`` dicts.
    """
    forecasts = []
    window_df = seed_df.copy()

    for _ in range(future_days):
        # Scale the current window and shape it as a single-sample batch.
        scaled_window = scaler_X.transform(window_df[features].values)
        model_input = torch.tensor(
            scaled_window.reshape(1, sequence_length, -1), dtype=torch.float32
        ).to(DEVICE)

        # The model may emit several future points; only the first one is used.
        with torch.no_grad():
            scaled_output = model(model_input).cpu().numpy()

        first_step_scaled = scaled_output[0, 0].reshape(1, -1)
        predicted_sales = scaler_y.inverse_transform(first_step_scaled)[0][0]
        # Sales cannot be negative; also convert to a plain Python float.
        predicted_sales = float(max(0, predicted_sales))

        # NOTE(review): the forecast date continues from the last row of the
        # window, not from the requested start date. If the history has gaps
        # before start_date the first forecast date can differ from it.
        next_date = window_df['date'].iloc[-1] + timedelta(days=1)
        forecasts.append({'date': next_date, 'predicted_sales': predicted_sales})

        # Synthetic feature row for the predicted day: calendar fields come
        # from the date, holiday/promotion flags are set to 0, and the
        # temperature is carried over from the previous day.
        next_row = {
            'date': next_date,
            'sales': predicted_sales,
            'weekday': next_date.weekday(),
            'month': next_date.month,
            'is_holiday': 0,
            'is_weekend': 1 if next_date.weekday() >= 5 else 0,
            'is_promotion': 0,
            'temperature': window_df['temperature'].iloc[-1],
        }

        # Slide the window: drop the oldest row, append the predicted one.
        window_df = pd.concat(
            [window_df.iloc[1:], pd.DataFrame([next_row])], ignore_index=True
        )

    return forecasts


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None, future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False, version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
    """Load a trained checkpoint and produce an autoregressive sales forecast.

    Args:
        model_path: Path of the checkpoint file; it must exist.
        product_id: Product identifier. Used to select the data in the default
            mode, echoed in the result and passed to the analysis step.
        model_type: One of 'transformer', 'mlstm', 'kan', 'optimized_kan',
            'tcn'.
        store_id: Store identifier; only used when ``training_mode`` is
            'store'.
        future_days: Number of one-day forecast steps to generate.
        start_date: Optional date string. When given, only rows dated strictly
            before it seed the forecast; otherwise the forecast starts the day
            after the last row of the data.
        analyze_result: When true, run ``analyze_prediction_result`` on the
            forecast; a failure there is reported and leaves ``analysis`` as
            ``None``.
        version: Accepted for interface compatibility; not used in this
            function body.
        training_mode: 'store', 'global', or anything else for per-product
            data.
        history_lookback_days: Number of historical rows returned for chart
            display.

    Returns:
        A dict with keys ``product_id``, ``product_name``, ``model_type``,
        ``predictions``, ``prediction_data``, ``history_data`` and
        ``analysis``; or ``None`` when the model file is missing, the data is
        empty or too short, the checkpoint is incomplete, the model type is
        unsupported, or any unexpected exception occurs (the traceback is
        printed).
    """
    try:
        print(f"v3版预测函数启动,模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")

        if not os.path.exists(model_path):
            print(f"模型文件 {model_path} 不存在")
            return None

        # Load the sales data for the requested scope.
        from utils.multi_store_data_utils import aggregate_multi_store_data
        if training_mode == 'store' and store_id:
            # Aggregation drops the store name, so read it from the raw data first.
            from utils.multi_store_data_utils import load_multi_store_data
            store_df_for_name = load_multi_store_data(store_id=store_id)
            product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"店铺 {store_id}"

            # Then aggregate to obtain the series used for forecasting.
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "全局销售数据"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id

        if product_df.empty:
            print(f"产品 {product_id} 或店铺 {store_id} 没有销售数据")
            return None

        # Best-effort registration of MinMaxScaler as a safe global. Older
        # torch releases do not provide add_safe_globals, so probe for it
        # instead of swallowing every exception silently.
        register_safe_globals = getattr(torch.serialization, 'add_safe_globals', None)
        if register_safe_globals is not None:
            try:
                register_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
            except Exception as exc:
                print(f"注册MinMaxScaler安全反序列化类型失败(已忽略): {exc}")

        # NOTE(security): weights_only=False unpickles arbitrary objects from
        # the checkpoint. Only load model files from trusted sources.
        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("模型文件不完整,缺少config或scaler")
            return None
        # The weights are indexed unconditionally below; fail with a clear
        # message instead of a generic KeyError traceback.
        if 'model_state_dict' not in checkpoint:
            print("模型文件不完整,缺少model_state_dict")
            return None

        config = checkpoint['config']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Rebuild the network with the hyper-parameters stored in the checkpoint.
        model = _build_forecast_model(model_type, config)
        if model is None:
            print(f"不支持的模型类型: {model_type}")
            return None

        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # --- Autoregressive forecasting ---
        features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        sequence_length = config['sequence_length']

        # Pick the seed window. tail() is used throughout, which presumably
        # relies on product_df being ordered by date - TODO confirm.
        if start_date:
            start_date_dt = pd.to_datetime(start_date)
            # The last `sequence_length` rows strictly before the start date.
            prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
        else:
            # No start date: forecast from the day after the last data row.
            prediction_input_df = product_df.tail(sequence_length)
            start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

        if len(prediction_input_df) < sequence_length:
            print(f"错误: 预测所需的历史数据不足。需要 {sequence_length} 天, 但只有 {len(prediction_input_df)} 天。")
            return None

        # History shown next to the forecast in charts.
        history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

        print(f"开始自动回归预测,共 {future_days} 天...")
        all_predictions = _autoregressive_forecast(
            model, scaler_X, scaler_y, prediction_input_df, features, sequence_length, future_days
        )

        predictions_df = pd.DataFrame(all_predictions)
        print(f"自动回归预测完成,生成 {len(predictions_df)} 条预测数据。")

        # Optional analysis; failures here must not discard the forecast.
        analysis = None
        if analyze_result:
            try:
                y_pred_for_analysis = predictions_df['predicted_sales'].values
                # The analysis uses the features of the initial seed window.
                initial_features_for_analysis = prediction_input_df[features].values
                analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
            except Exception as e:
                print(f"分析预测结果失败: {str(e)}")

        # Convert the frames to lists of record dicts for the caller.
        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': prediction_data_json,  # legacy field, same payload as prediction_data
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }
    except Exception as e:
        # Top-level boundary: report the failure and signal it with None.
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None