ShopTRAINING/server/predictors/model_predictor.py
xz2000 87df49f764 添加训练算法模拟xgboosT,训练可以完成,预测读取还有问题
数据文件保存机构改为### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
2025-07-21 18:47:27 +08:00

417 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
import xgboost as xgb
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
import joblib
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE
from utils.file_save import ModelPathManager
def load_model_and_predict(product_id, model_type, model_path=None, store_id=None, future_days=7, start_date=None, analyze_result=False, version=None, training_mode='product', **kwargs):
"""
加载已训练的模型并进行预测
参数:
product_id: 产品ID
model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan', 'xgboost')
model_path: 模型的完整文件路径
store_id: 店铺ID为None时使用全局模型
future_days: 预测未来天数
start_date: 预测起始日期如果为None则使用最后一个已知日期
analyze_result: 是否分析预测结果
version: 模型版本
返回:
预测结果和分析如果analyze_result为True
"""
try:
print(f"尝试加载模型文件: {model_path}")
# 如果没有提供 model_path则使用 ModelPathManager 动态生成
if not model_path:
if version is None:
raise ValueError("使用动态路径加载时必须提供 'version'")
path_manager = ModelPathManager()
# 传递所有必要的参数以重构路径
path_params = {
'product_id': product_id,
'store_id': store_id,
**kwargs
}
model_path = path_manager.get_model_path_for_prediction(
training_mode=training_mode,
model_type=model_type,
version=version,
**path_params
)
if not model_path or not os.path.exists(model_path):
print(f"模型文件 {model_path} 不存在或无法生成。")
return None
# 加载销售数据(支持多店铺)
try:
if store_id:
# 加载特定店铺的数据
product_df = get_store_product_sales_data(
store_id,
product_id,
None # 使用默认数据路径
)
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
prediction_scope = f"店铺 '{store_name}' ({store_id})"
else:
# 聚合所有店铺的数据进行预测
product_df = aggregate_multi_store_data(
product_id,
aggregation_method='sum',
file_path=None # 使用默认数据路径
)
prediction_scope = "全部店铺(聚合数据)"
except Exception as e:
print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
# 后向兼容:尝试加载原始数据格式
try:
from core.config import DEFAULT_DATA_PATH
from utils.multi_store_data_utils import load_multi_store_data
df = load_multi_store_data(DEFAULT_DATA_PATH)
product_df = df[df['product_id'] == product_id].sort_values('date')
if store_id:
print(f"警告:原始数据不支持店铺过滤,将使用所有数据预测")
prediction_scope = "默认数据"
except Exception as e2:
print(f"加载产品数据失败: {str(e2)}")
return None
if product_df.empty:
print(f"产品 {product_id} 没有销售数据")
return None
product_name = product_df['product_name'].iloc[0]
print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
print(f"预测范围: {prediction_scope}")
# 添加安全的全局变量以支持MinMaxScaler的反序列化
try:
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
except Exception as e:
print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")
# 加载模型和配置
try:
# 首先尝试使用weights_only=False加载
if model_type == 'xgboost':
if not os.path.exists(model_path):
print(f"XGBoost模型文件不存在: {model_path}")
return None
# 加载元数据
metadata = joblib.load(model_path)
model_file_path = metadata['model_file']
if not os.path.exists(model_file_path):
print(f"引用的XGBoost模型文件不存在: {model_file_path}")
return None
# 加载原生Booster模型
model = xgb.Booster()
model.load_model(model_file_path)
config = metadata['config']
metrics = metadata['metrics']
scaler_X = metadata['scaler_X']
scaler_y = metadata['scaler_y']
print("XGBoost原生模型及元数据加载成功")
else:
try:
print("尝试使用 weights_only=False 加载模型")
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
except Exception as e:
print(f"使用weights_only=False加载失败: {str(e)}")
print("尝试使用默认参数加载模型")
checkpoint = torch.load(model_path, map_location=DEVICE)
print(f"模型加载成功检查checkpoint类型: {type(checkpoint)}")
if isinstance(checkpoint, dict):
print(f"checkpoint包含的键: {list(checkpoint.keys())}")
else:
print(f"checkpoint不是字典类型而是: {type(checkpoint)}")
return None
except Exception as e:
print(f"加载模型失败: {str(e)}")
return None
# XGBoost有不同的处理逻辑
if model_type == 'xgboost':
look_back = config['look_back']
features = config['features']
# 准备输入数据
recent_data = product_df.iloc[-look_back:].copy()
predictions = []
current_input_df = recent_data[features].copy()
for _ in range(future_days):
# 归一化输入数据并展平
input_scaled = scaler_X.transform(current_input_df.values)
input_vector = input_scaled.flatten().reshape(1, -1)
# 预测缩放后的值
dpredict = xgb.DMatrix(input_vector)
prediction_scaled = model.predict(dpredict)
# 反归一化得到真实预测值
prediction = scaler_y.inverse_transform(prediction_scaled.reshape(-1, 1)).flatten()[0]
predictions.append(prediction)
# 更新输入窗口以进行下一次预测
# 创建新的一行,包含真实的预测值
new_row_values = current_input_df.iloc[-1].copy()
new_row_values['sales'] = prediction
# 可以在这里添加更复杂的未来特征生成逻辑例如根据新日期更新weekday, month等
new_row_df = pd.DataFrame([new_row_values], columns=features)
# 滚动窗口
current_input_df = pd.concat([current_input_df.iloc[1:], new_row_df], ignore_index=True)
# 生成预测日期
last_date = recent_data['date'].iloc[-1]
pred_dates = [last_date + timedelta(days=i+1) for i in range(future_days)]
y_pred = np.array(predictions)
else: # 原有的PyTorch模型逻辑
# 检查并获取配置
if 'config' not in checkpoint:
print("模型文件中没有配置信息")
return None
config = checkpoint['config']
print(f"模型配置: {config}")
# 检查并获取缩放器
if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
print("模型文件中没有缩放器信息")
return None
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例
try:
if model_type == 'transformer':
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
nhead=config['num_heads'],
num_encoder_layers=config['num_layers'],
dim_feedforward=config['hidden_size'] * 2,
dropout=config['dropout'],
output_sequence_length=config['output_dim'],
seq_length=config['sequence_length'],
batch_size=32
).to(DEVICE)
elif model_type == 'slstm':
model = ScalarLSTM(
input_dim=config['input_dim'],
hidden_dim=config['hidden_size'],
output_dim=config['output_dim'],
num_layers=config['num_layers'],
dropout=config['dropout']
).to(DEVICE)
elif model_type == 'mlstm':
# 获取配置参数,如果不存在则使用默认值
embed_dim = config.get('embed_dim', 32)
dense_dim = config.get('dense_dim', 32)
num_heads = config.get('num_heads', 4)
num_blocks = config.get('num_blocks', 3)
model = MatrixLSTM(
num_features=config['input_dim'],
hidden_size=config['hidden_size'],
mlstm_layers=config['num_layers'],
embed_dim=embed_dim,
dense_dim=dense_dim,
num_heads=num_heads,
dropout_rate=config['dropout'],
num_blocks=num_blocks,
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'kan':
model = KANForecaster(
input_features=config['input_dim'],
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'optimized_kan':
model = OptimizedKANForecaster(
input_features=config['input_dim'],
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'tcn':
model = TCNForecaster(
num_features=config['input_dim'],
output_sequence_length=config['output_dim'],
num_channels=[config['hidden_size']] * config['num_layers'],
kernel_size=3,
dropout=config['dropout']
).to(DEVICE)
else:
print(f"不支持的模型类型: {model_type}")
return None
print(f"模型实例创建成功: {type(model)}")
except Exception as e:
print(f"创建模型实例失败: {str(e)}")
return None
# 加载模型参数
try:
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("模型参数加载成功")
except Exception as e:
print(f"加载模型参数失败: {str(e)}")
return None
# 准备输入数据
try:
features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
sequence_length = config['sequence_length']
# 获取最近的sequence_length天数据作为输入
recent_data = product_df.iloc[-sequence_length:].copy()
# 如果指定了起始日期,则使用该日期之后的数据
if start_date:
if isinstance(start_date, str):
start_date = datetime.strptime(start_date, '%Y-%m-%d')
recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
if len(recent_data) < sequence_length:
print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length}")
# 补充数据
missing_days = sequence_length - len(recent_data)
additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
print(f"输入数据准备完成,形状: {recent_data.shape}")
except Exception as e:
print(f"准备输入数据失败: {str(e)}")
return None
# 归一化输入数据
try:
X = recent_data[features].values
X_scaled = scaler_X.transform(X)
# 转换为模型输入格式
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
print(f"输入张量准备完成,形状: {X_input.shape}")
except Exception as e:
print(f"归一化输入数据失败: {str(e)}")
return None
# 预测
try:
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
print(f"原始预测输出形状: {y_pred_scaled.shape}")
# 处理TCN、Transformer、mLSTM和KAN模型的输出确保形状正确
if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
y_pred_scaled = y_pred_scaled.squeeze(-1)
print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
# 反归一化预测结果
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
print(f"反归一化后的预测结果: {y_pred}")
# 生成预测日期
last_date = recent_data['date'].iloc[-1]
pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
print(f"预测日期: {pred_dates}")
except Exception as e:
print(f"执行预测失败: {str(e)}")
return None
# 创建预测结果DataFrame
try:
predictions_df = pd.DataFrame({
'date': pred_dates,
'sales': y_pred # 使用sales字段名而不是predicted_sales以便与历史数据兼容
})
print(f"预测结果DataFrame创建成功形状: {predictions_df.shape}")
except Exception as e:
print(f"创建预测结果DataFrame失败: {str(e)}")
return None
# 绘制预测结果
try:
plt.figure(figsize=(12, 6))
plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型销量预测')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 保存图像
plt.savefig(f'{product_id}_{model_type}_prediction.png')
plt.close()
print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png")
except Exception as e:
print(f"绘制预测结果图表失败: {str(e)}")
# 这个错误不影响主要功能,继续执行
# 分析预测结果
analysis = None
if analyze_result:
try:
analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
print("\n预测结果分析:")
if analysis and 'explanation' in analysis:
print(analysis['explanation'])
else:
print("分析结果不包含explanation字段")
except Exception as e:
print(f"分析预测结果失败: {str(e)}")
# 分析失败不影响主要功能,继续执行
return {
'product_id': product_id,
'product_name': product_name,
'model_type': model_type,
'predictions': predictions_df,
'analysis': analysis
}
except Exception as e:
print(f"预测过程中出现未捕获的异常: {str(e)}")
import traceback
traceback.print_exc()
return None