ShopTRAINING/server/predictors/model_predictor.py

381 lines
17 KiB
Python
Raw Normal View History

"""
Pharmacy sales forecasting system - model prediction functions.
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
2025-07-02 11:05:23 +08:00
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE, get_model_file_path
2025-07-02 11:05:23 +08:00
def _candidate_model_dirs(model_type, store_id):
    """Return the directories searched for a model file, most specific first."""
    if store_id:
        # os.path.join only accepts str components, so normalise e.g. int store IDs.
        store_dir = str(store_id)
        return [
            os.path.join('saved_models', model_type, store_dir),
            os.path.join('models', model_type, store_dir),
        ]
    return [
        os.path.join('saved_models', model_type, 'global'),
        os.path.join('models', model_type, 'global'),
        os.path.join('saved_models', model_type),  # backward compatibility
        'saved_models',  # most basic fallback directory
    ]


def _candidate_model_names(product_id, model_type):
    """Return the model file names tried in each directory, in priority order."""
    # The legacy naming scheme stores optimized KAN models as "kan_optimized_...".
    model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
    file_model_type = 'kan' if model_type == 'optimized_kan' else model_type
    return [
        f"{product_id}_{model_type}_v1_model.pt",  # new multi-store format
        f"{product_id}_{model_type}_v1_global_model.pt",  # global model format
        f"{product_id}_{model_type}_v1.pth",  # old versioned format
        f"{file_model_type}{model_suffix}_model_product_{product_id}.pth",  # original format
        f"{model_type}_model_product_{product_id}.pth",  # simplified format
    ]


def _search_model_file(dirs, names):
    """Return the first existing ``dir/name`` path (dirs outer, names inner), or None."""
    for dir_path in dirs:
        if not os.path.exists(dir_path):
            continue
        for name in names:
            candidate = os.path.join(dir_path, name)
            if os.path.exists(candidate):
                return candidate
    return None


def _load_product_data(product_id, store_id):
    """Load the product's sales rows; return (product_df, prediction_scope) or None.

    Tries the multi-store CSV first (one store, or all stores summed) and falls
    back to the legacy 'pharmacy_sales.xlsx' file if that raises.
    """
    try:
        if store_id:
            # Data for one specific store.
            product_df = get_store_product_sales_data(
                store_id,
                product_id,
                'pharmacy_sales_multi_store.csv'
            )
            # Guard against empty results: .iloc[0] on an empty frame would raise
            # and wrongly trigger the unfiltered legacy-data fallback below.
            if 'store_name' in product_df.columns and not product_df.empty:
                store_name = product_df['store_name'].iloc[0]
            else:
                store_name = f"店铺{store_id}"
            prediction_scope = f"店铺 '{store_name}' ({store_id})"
        else:
            # Aggregate (sum) all stores' data.
            product_df = aggregate_multi_store_data(
                product_id,
                aggregation_method='sum',
                file_path='pharmacy_sales_multi_store.csv'
            )
            prediction_scope = "全部店铺(聚合数据)"
    except Exception as e:
        print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
        # Backward compatibility: legacy single-file data format.
        try:
            df = pd.read_excel('pharmacy_sales.xlsx')
            product_df = df[df['product_id'] == product_id].sort_values('date')
            if store_id:
                print("警告:原始数据不支持店铺过滤,将使用所有数据预测")
            prediction_scope = "默认数据"
        except Exception as e2:
            print(f"加载产品数据失败: {str(e2)}")
            return None
    return product_df, prediction_scope


def _load_checkpoint(model_path):
    """Load a checkpoint dict from ``model_path``; return None if loading fails or it is not a dict."""
    # Allow-list MinMaxScaler so scalers stored in the checkpoint can be unpickled.
    try:
        torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
    except Exception as e:
        print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")
    try:
        try:
            print("尝试使用 weights_only=False 加载模型")
            # SECURITY: weights_only=False unpickles arbitrary Python objects;
            # only load model files from trusted locations.
            checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
        except Exception as e:
            print(f"使用weights_only=False加载失败: {str(e)}")
            print("尝试使用默认参数加载模型")
            # Retry with torch.load defaults (presumably for torch versions
            # without the weights_only keyword).
            checkpoint = torch.load(model_path, map_location=DEVICE)
        print(f"模型加载成功检查checkpoint类型: {type(checkpoint)}")
        if not isinstance(checkpoint, dict):
            print(f"checkpoint不是字典类型而是: {type(checkpoint)}")
            return None
        print(f"checkpoint包含的键: {list(checkpoint.keys())}")
        return checkpoint
    except Exception as e:
        print(f"加载模型失败: {str(e)}")
        return None


def _build_model(model_type, config):
    """Instantiate the untrained network for ``model_type`` from a checkpoint config.

    Returns None for an unsupported model_type. A missing config key raises
    KeyError, which the caller reports as a model-construction failure.
    """
    if model_type == 'transformer':
        model = TimeSeriesTransformer(
            num_features=config['input_dim'],
            d_model=config['hidden_size'],
            nhead=config['num_heads'],
            num_encoder_layers=config['num_layers'],
            dim_feedforward=config['hidden_size'] * 2,
            dropout=config['dropout'],
            output_sequence_length=config['output_dim'],
            seq_length=config['sequence_length'],
            batch_size=32
        )
    elif model_type == 'slstm':
        model = ScalarLSTM(
            input_dim=config['input_dim'],
            hidden_dim=config['hidden_size'],
            output_dim=config['output_dim'],
            num_layers=config['num_layers'],
            dropout=config['dropout']
        )
    elif model_type == 'mlstm':
        # Optional hyper-parameters; older checkpoints may not store them.
        embed_dim = config.get('embed_dim', 32)
        dense_dim = config.get('dense_dim', 32)
        num_heads = config.get('num_heads', 4)
        num_blocks = config.get('num_blocks', 3)
        model = MatrixLSTM(
            num_features=config['input_dim'],
            hidden_size=config['hidden_size'],
            mlstm_layers=config['num_layers'],
            embed_dim=embed_dim,
            dense_dim=dense_dim,
            num_heads=num_heads,
            dropout_rate=config['dropout'],
            num_blocks=num_blocks,
            output_sequence_length=config['output_dim']
        )
    elif model_type in ('kan', 'optimized_kan'):
        forecaster_cls = KANForecaster if model_type == 'kan' else OptimizedKANForecaster
        model = forecaster_cls(
            input_features=config['input_dim'],
            hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
            output_sequence_length=config['output_dim']
        )
    elif model_type == 'tcn':
        model = TCNForecaster(
            num_features=config['input_dim'],
            output_sequence_length=config['output_dim'],
            num_channels=[config['hidden_size']] * config['num_layers'],
            kernel_size=3,
            dropout=config['dropout']
        )
    else:
        return None
    return model.to(DEVICE)


def _select_input_window(product_df, sequence_length, start_date):
    """Return the rows used as model input (may be shorter than sequence_length).

    Without start_date this is the last ``sequence_length`` rows. With start_date,
    it is the first ``sequence_length`` rows on/after that date, prefixed with
    rows before it when too few exist.
    """
    recent_data = product_df.iloc[-sequence_length:].copy()
    if start_date:
        if isinstance(start_date, str):
            start_date = datetime.strptime(start_date, '%Y-%m-%d')
        recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
        if len(recent_data) < sequence_length:
            print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length}")
            # Pad with the rows immediately preceding start_date.
            missing_days = sequence_length - len(recent_data)
            additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
            recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
    return recent_data


def _plot_predictions(product_df, predictions_df, product_name, product_id, model_type):
    """Plot history vs. forecast to '<product_id>_<model_type>_prediction.png' (always closes the figure)."""
    output_path = f'{product_id}_{model_type}_prediction.png'
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
        plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
        plt.title(f'{product_name} - {model_type}模型销量预测')
        plt.xlabel('日期')
        plt.ylabel('销量')
        plt.legend()
        plt.grid(True)
        plt.xticks(rotation=45)
        plt.tight_layout()
        plt.savefig(output_path)
    finally:
        # Close even if plotting/saving raised, so figures do not accumulate.
        plt.close(fig)
    print(f"预测结果已保存到 {output_path}")


def load_model_and_predict(product_id, model_type, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None):
    """Load a trained model and forecast future sales for one product.

    Args:
        product_id: Product ID.
        model_type: 'transformer', 'slstm', 'mlstm', 'kan', 'tcn' or 'optimized_kan'.
        store_id: Store ID. When falsy (e.g. None), the global model is used and
            sales are aggregated over all stores.
        future_days: Requested forecast horizon. Only used in log output; the
            number of predicted days equals the size of the model output
            (presumably config['output_dim']).
        start_date: Optional datetime or 'YYYY-MM-DD' string. When given, the
            input window starts at this date (padded with earlier rows if too
            short); otherwise the most recent rows are used.
        analyze_result: Whether to run analyze_prediction_result on the forecast.
        version: Model version. When given, the model path comes from
            get_model_file_path instead of searching the model directories.

    Returns:
        A dict with keys 'product_id', 'product_name', 'model_type',
        'predictions' (DataFrame with 'date' and 'sales' columns) and
        'analysis' (analysis result or None), or None on any failure.
        Errors are printed, not raised.
    """
    try:
        # --- 1. Locate the model file (per-store or global) -------------------
        searched_dirs = []  # defined up-front so the not-found message cannot NameError
        if version:
            # Resolve the path through the version-management system.
            model_path = get_model_file_path(product_id, model_type, version)
        else:
            searched_dirs = _candidate_model_dirs(model_type, store_id)
            model_path = _search_model_file(searched_dirs, _candidate_model_names(product_id, model_type))

        if not model_path:
            scope_msg = f"店铺 {store_id}" if store_id else "全局"
            print(f"找不到产品 {product_id}{model_type} 模型文件 ({scope_msg})")
            print(f"搜索目录: {searched_dirs}")
            return None

        print(f"尝试加载模型文件: {model_path}")
        if not os.path.exists(model_path):
            print(f"模型文件 {model_path} 不存在")
            return None

        # --- 2. Load sales history --------------------------------------------
        loaded = _load_product_data(product_id, store_id)
        if loaded is None:
            return None
        product_df, prediction_scope = loaded
        if product_df.empty:
            print(f"产品 {product_id} 没有销售数据")
            return None

        product_name = product_df['product_name'].iloc[0]
        print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
        print(f"预测范围: {prediction_scope}")

        # --- 3. Load checkpoint, config and scalers ---------------------------
        checkpoint = _load_checkpoint(model_path)
        if checkpoint is None:
            return None
        if 'config' not in checkpoint:
            print("模型文件中没有配置信息")
            return None
        config = checkpoint['config']
        print(f"模型配置: {config}")
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("模型文件中没有缩放器信息")
            return None
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # --- 4. Rebuild the network and restore its weights -------------------
        try:
            model = _build_model(model_type, config)
            if model is None:
                print(f"不支持的模型类型: {model_type}")
                return None
            print(f"模型实例创建成功: {type(model)}")
        except Exception as e:
            print(f"创建模型实例失败: {str(e)}")
            return None
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            print("模型参数加载成功")
        except Exception as e:
            print(f"加载模型参数失败: {str(e)}")
            return None

        # --- 5. Prepare the input window --------------------------------------
        # Feature order must match the one the scaler/model were trained with.
        features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
        try:
            sequence_length = config['sequence_length']
            recent_data = _select_input_window(product_df, sequence_length, start_date)
            print(f"输入数据准备完成,形状: {recent_data.shape}")
        except Exception as e:
            print(f"准备输入数据失败: {str(e)}")
            return None
        if len(recent_data) < sequence_length:
            # Without this check the reshape below fails with an opaque error.
            print(f"准备输入数据失败: 历史数据不足,需要 {sequence_length} 行,实际只有 {len(recent_data)} 行")
            return None

        try:
            X = recent_data[features].values
            X_scaled = scaler_X.transform(X)
            # Batch of one sequence: (1, sequence_length, n_features).
            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            print(f"输入张量准备完成,形状: {X_input.shape}")
        except Exception as e:
            print(f"归一化输入数据失败: {str(e)}")
            return None

        # --- 6. Predict and de-normalise --------------------------------------
        try:
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            print(f"原始预测输出形状: {y_pred_scaled.shape}")
            # These architectures may emit a trailing singleton axis; drop it.
            if model_type in ('tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan') and y_pred_scaled.ndim == 3:
                y_pred_scaled = y_pred_scaled.squeeze(-1)
                print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
            print(f"反归一化后的预测结果: {y_pred}")
            # One calendar day per predicted value, starting after the last input day.
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [last_date + timedelta(days=offset + 1) for offset in range(len(y_pred))]
            print(f"预测日期: {pred_dates}")
        except Exception as e:
            print(f"执行预测失败: {str(e)}")
            return None

        try:
            predictions_df = pd.DataFrame({
                'date': pred_dates,
                'sales': y_pred  # named 'sales' (not 'predicted_sales') to match the history frame
            })
            print(f"预测结果DataFrame创建成功形状: {predictions_df.shape}")
        except Exception as e:
            print(f"创建预测结果DataFrame失败: {str(e)}")
            return None

        # --- 7. Best-effort plot and analysis (failures do not abort) ---------
        try:
            _plot_predictions(product_df, predictions_df, product_name, product_id, model_type)
        except Exception as e:
            print(f"绘制预测结果图表失败: {str(e)}")

        analysis = None
        if analyze_result:
            try:
                analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
                print("\n预测结果分析:")
                if analysis and 'explanation' in analysis:
                    print(analysis['explanation'])
                else:
                    print("分析结果不包含explanation字段")
            except Exception as e:
                print(f"分析预测结果失败: {str(e)}")

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
            'analysis': analysis
        }
    except Exception as e:
        print(f"预测过程中出现未捕获的异常: {str(e)}")
        import traceback
        traceback.print_exc()
        return None