ShopTRAINING/server/predictors/model_predictor.py


"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # needed so a pickled MinMaxScaler can be deserialized from the checkpoint
from typing import Optional
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path, DEFAULT_DATA_PATH


def load_model_and_predict(model_path: str, product_id: str, model_type: str, store_id: Optional[str] = None,
                           future_days: int = 7, start_date: Optional[str] = None, analyze_result: bool = False,
                           version: Optional[str] = None, training_mode: str = 'product', history_lookback_days: int = 30):
"""
加载已训练的模型并进行预测 (v3版 - 支持自动回归)
参数:
... (同上, 新增 history_lookback_days)
history_lookback_days: 用于图表展示的历史数据天数
返回:
预测结果和分析
"""
try:
print(f"v3版预测函数启动模型路径: {model_path}, 预测天数: {future_days}, 历史回看: {history_lookback_days}")
        if not os.path.exists(model_path):
            print(f"Model file {model_path} does not exist")
            return None

        # Load the sales data (aggregate_multi_store_data is imported at module level)
        if training_mode == 'store' and store_id:
            # Load the raw data once first to get the store name; aggregation drops this information
            from utils.multi_store_data_utils import load_multi_store_data
            store_df_for_name = load_multi_store_data(store_id=store_id)
            product_name = store_df_for_name['store_name'].iloc[0] if not store_df_for_name.empty else f"Store {store_id}"
            # Then aggregate to obtain the data actually used for prediction
            product_df = aggregate_multi_store_data(store_id=store_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
        elif training_mode == 'global':
            product_df = aggregate_multi_store_data(aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = "Global sales data"
        else:
            product_df = aggregate_multi_store_data(product_id=product_id, aggregation_method='sum', file_path=DEFAULT_DATA_PATH)
            product_name = product_df['product_name'].iloc[0] if not product_df.empty else product_id

        if product_df.empty:
            print(f"No sales data found for product {product_id} or store {store_id}")
            return None
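
        # Note (added explanation): the checkpoint below is loaded with weights_only=False
        # because it stores fitted sklearn scalers alongside the state dict; the
        # add_safe_globals call allowlists MinMaxScaler for newer PyTorch versions,
        # which default to restricted weights-only unpickling.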
        # Load the model and its configuration
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception:
            pass

        checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)

        if 'config' not in checkpoint or 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("Model file is incomplete: missing config or scaler")
            return None

        config = checkpoint['config']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']
        # Build the model instance
        # (model construction is identical to the previous version; kept concise here)
        if model_type == 'transformer':
            model = TimeSeriesTransformer(num_features=config['input_dim'], d_model=config['hidden_size'], nhead=config['num_heads'], num_encoder_layers=config['num_layers'], dim_feedforward=config['hidden_size'] * 2, dropout=config['dropout'], output_sequence_length=config['output_dim'], seq_length=config['sequence_length'], batch_size=32).to(DEVICE)
        elif model_type == 'mlstm':
            model = MatrixLSTM(num_features=config['input_dim'], hidden_size=config['hidden_size'], mlstm_layers=config['mlstm_layers'], embed_dim=config.get('embed_dim', 32), dense_dim=config.get('dense_dim', 32), num_heads=config.get('num_heads', 4), dropout_rate=config['dropout_rate'], num_blocks=config.get('num_blocks', 3), output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'kan':
            model = KANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'optimized_kan':
            model = OptimizedKANForecaster(input_features=config['input_dim'], hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']], output_sequence_length=config['output_dim']).to(DEVICE)
        elif model_type == 'tcn':
            model = TCNForecaster(num_features=config['input_dim'], output_sequence_length=config['output_dim'], num_channels=[config['hidden_size']] * config['num_layers'], kernel_size=config['kernel_size'], dropout=config['dropout']).to(DEVICE)
        else:
            print(f"Unsupported model type: {model_type}")
            return None

        model.load_state_dict(checkpoint['model_state_dict'])
        model.eval()

        # --- Core logic change: autoregressive prediction ---
        features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        sequence_length = config['sequence_length']
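        # Added note: this feature list and its ordering must match the columns that
        # scaler_X / scaler_y were fitted on during training; changing the order here
        # would silently mis-scale the model input.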

        # Determine the starting point of the forecast
        if start_date:
            start_date_dt = pd.to_datetime(start_date)
            # Use the `sequence_length` days before the start date as the initial input
            prediction_input_df = product_df[product_df['date'] < start_date_dt].tail(sequence_length)
        else:
            # If no start date is given, forecast from the day after the last available date
            prediction_input_df = product_df.tail(sequence_length)
            start_date_dt = product_df['date'].iloc[-1] + timedelta(days=1)

        if len(prediction_input_df) < sequence_length:
            print(f"Error: not enough history for prediction. Need {sequence_length} days, but only {len(prediction_input_df)} are available.")
            return None

        # Prepare the historical data used for chart display
        history_for_chart_df = product_df[product_df['date'] < start_date_dt].tail(history_lookback_days)

        # Autoregressive prediction loop
        all_predictions = []
        current_sequence_df = prediction_input_df.copy()
        print(f"Starting autoregressive prediction for {future_days} days...")

        for i in range(future_days):
            # Build the input tensor from the current sequence
            X_current_scaled = scaler_X.transform(current_sequence_df[features].values)
            X_input = torch.tensor(X_current_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)

            # Run one forward pass (the model may output several steps, but only the first is used)
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()

            # Extract the prediction for the next time step
            next_step_pred_scaled = y_pred_scaled[0, 0].reshape(1, -1)
            next_step_pred_unscaled = scaler_y.inverse_transform(next_step_pred_scaled)[0][0]
            next_step_pred_unscaled = float(max(0, next_step_pred_unscaled))  # clamp negative sales to zero and convert to a plain float

            # Date of the newly predicted step
            next_date = current_sequence_df['date'].iloc[-1] + timedelta(days=1)
            all_predictions.append({'date': next_date, 'predicted_sales': next_step_pred_unscaled})

            # Build the new row used to update the input sequence
            new_row = {
                'date': next_date,
                'sales': next_step_pred_unscaled,
                'weekday': next_date.weekday(),
                'month': next_date.month,
                'is_holiday': 0,
                'is_weekend': 1 if next_date.weekday() >= 5 else 0,
                'is_promotion': 0,
                'temperature': current_sequence_df['temperature'].iloc[-1]  # carry over the last day's temperature
            }

            # Slide the window: drop the oldest row and append the newly predicted one
            new_row_df = pd.DataFrame([new_row])
            current_sequence_df = pd.concat([current_sequence_df.iloc[1:], new_row_df], ignore_index=True)

        predictions_df = pd.DataFrame(all_predictions)
        print(f"Autoregressive prediction finished; generated {len(predictions_df)} prediction rows.")

        # Analysis and visualization
        analysis = None
        if analyze_result:
            try:
                y_pred_for_analysis = predictions_df['predicted_sales'].values
                # Analyze against the features of the initial input sequence
                initial_features_for_analysis = prediction_input_df[features].values
                analysis = analyze_prediction_result(product_id, model_type, y_pred_for_analysis, initial_features_for_analysis)
            except Exception as e:
                print(f"Failed to analyze prediction results: {str(e)}")

        # Convert the DataFrames to the JSON array format expected by the frontend before returning
        history_data_json = history_for_chart_df.to_dict('records') if not history_for_chart_df.empty else []
        prediction_data_json = predictions_df.to_dict('records') if not predictions_df.empty else []

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': prediction_data_json,  # legacy field, kept for compatibility; uses the already converted JSON
            'prediction_data': prediction_data_json,
            'history_data': history_data_json,
            'analysis': analysis
        }

    except Exception as e:
        print(f"Unhandled exception during prediction: {str(e)}")
        import traceback
        traceback.print_exc()
        return None
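

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module): the checkpoint path and
    # product id below are hypothetical placeholders; point them at an actual model
    # file produced by the training pipeline before running.
    result = load_model_and_predict(
        model_path="saved_models/transformer_product_P001_v1.pth",  # hypothetical path
        product_id="P001",                                           # hypothetical product id
        model_type="transformer",
        future_days=7,
        history_lookback_days=30,
        analyze_result=False,
    )
    if result is not None:
        print(f"Predicted {len(result['prediction_data'])} days for {result['product_name']}")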