ShopTRAINING/server/predictors/model_predictor.py
xz2000 e999ed4af2 ### 2025-07-15 (续): 训练器与核心调用层重构
**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**
*   **内容**: 完全重写了 `kan_trainer.py`。
    *   **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
    *   **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
    *   **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
    *   **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**
*   **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
    *   移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
    *   全面转向使用 `model_manager` 进行版本控制和文件保存。
    *   统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**
*   **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
    *   移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
    *   实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**
*   **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
    *   **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
    *   **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
    *   **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
    *   **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
2025-07-15 20:09:09 +08:00

340 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 模型预测函数
"""
import os
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data # 添加这一行以支持MinMaxScaler的反序列化
from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster
from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE
def load_model_and_predict(model_version_path: str, future_days=7, start_date=None, analyze_result=False):
"""
从指定的版本目录加载模型并进行预测。
参数:
model_version_path: 模型版本目录的绝对路径。
future_days: 预测未来天数。
start_date: 预测起始日期如果为None则使用最后一个已知日期。
analyze_result: 是否分析预测结果。
返回:
预测结果和分析如果analyze_result为True
"""
try:
# 从路径中解析元数据
metadata_path = os.path.join(model_version_path, 'metadata.json')
if not os.path.exists(metadata_path):
raise FileNotFoundError(f"在路径 {model_version_path} 中未找到 metadata.json")
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
product_id = metadata.get('product_id')
model_type = metadata.get('model_type')
store_id = metadata.get('store_id')
training_mode = metadata.get('training_mode')
aggregation_method = metadata.get('aggregation_method')
model_path = os.path.join(model_version_path, 'model.pth')
print(f"尝试加载模型文件: {model_path}")
if not os.path.exists(model_path):
print(f"模型文件 {model_path} 不存在")
return None
# 加载销售数据(支持多店铺)
try:
if store_id:
# 加载特定店铺的数据
product_df = get_store_product_sales_data(
store_id,
product_id,
'pharmacy_sales_multi_store.csv'
)
store_name = product_df['store_name'].iloc[0] if 'store_name' in product_df.columns else f"店铺{store_id}"
prediction_scope = f"店铺 '{store_name}' ({store_id})"
else:
# 聚合所有店铺的数据进行预测
product_df = aggregate_multi_store_data(
product_id,
aggregation_method='sum',
file_path='pharmacy_sales_multi_store.csv'
)
prediction_scope = "全部店铺(聚合数据)"
except Exception as e:
print(f"多店铺数据加载失败,尝试使用原始数据格式: {e}")
# 后向兼容:尝试加载原始数据格式
try:
df = pd.read_excel('pharmacy_sales.xlsx')
product_df = df[df['product_id'] == product_id].sort_values('date')
if store_id:
print(f"警告:原始数据不支持店铺过滤,将使用所有数据预测")
prediction_scope = "默认数据"
except Exception as e2:
print(f"加载产品数据失败: {str(e2)}")
return None
if product_df.empty:
print(f"产品 {product_id} 没有销售数据")
return None
product_name = product_df['product_name'].iloc[0]
print(f"使用 {model_type} 模型预测产品 '{product_name}' (ID: {product_id}) 的未来 {future_days} 天销量")
print(f"预测范围: {prediction_scope}")
# 添加安全的全局变量以支持MinMaxScaler的反序列化
try:
torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
except Exception as e:
print(f"添加安全全局变量失败,但这可能不影响模型加载: {str(e)}")
# 加载模型和配置
try:
# 首先尝试使用weights_only=False加载
try:
print("尝试使用 weights_only=False 加载模型")
checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
except Exception as e:
print(f"使用weights_only=False加载失败: {str(e)}")
print("尝试使用默认参数加载模型")
checkpoint = torch.load(model_path, map_location=DEVICE)
print(f"模型加载成功检查checkpoint类型: {type(checkpoint)}")
if isinstance(checkpoint, dict):
print(f"checkpoint包含的键: {list(checkpoint.keys())}")
else:
print(f"checkpoint不是字典类型而是: {type(checkpoint)}")
return None
except Exception as e:
print(f"加载模型失败: {str(e)}")
return None
# 检查并获取配置
if 'config' not in checkpoint:
print("模型文件中没有配置信息")
return None
config = checkpoint['config']
print(f"模型配置: {config}")
# 检查并获取缩放器
if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
print("模型文件中没有缩放器信息")
return None
scaler_X = checkpoint['scaler_X']
scaler_y = checkpoint['scaler_y']
# 创建模型实例
try:
if model_type == 'transformer':
model = TimeSeriesTransformer(
num_features=config['input_dim'],
d_model=config['hidden_size'],
nhead=config['num_heads'],
num_encoder_layers=config['num_layers'],
dim_feedforward=config['hidden_size'] * 2,
dropout=config['dropout'],
output_sequence_length=config['output_dim'],
seq_length=config['sequence_length'],
batch_size=32
).to(DEVICE)
elif model_type == 'slstm':
model = ScalarLSTM(
input_dim=config['input_dim'],
hidden_dim=config['hidden_size'],
output_dim=config['output_dim'],
num_layers=config['num_layers'],
dropout=config['dropout']
).to(DEVICE)
elif model_type == 'mlstm':
# 获取配置参数,如果不存在则使用默认值
embed_dim = config.get('embed_dim', 32)
dense_dim = config.get('dense_dim', 32)
num_heads = config.get('num_heads', 4)
num_blocks = config.get('num_blocks', 3)
model = MatrixLSTM(
num_features=config['input_dim'],
hidden_size=config['hidden_size'],
mlstm_layers=config['num_layers'],
embed_dim=embed_dim,
dense_dim=dense_dim,
num_heads=num_heads,
dropout_rate=config['dropout'],
num_blocks=num_blocks,
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'kan':
model = KANForecaster(
input_features=config['input_dim'],
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'optimized_kan':
model = OptimizedKANForecaster(
input_features=config['input_dim'],
hidden_sizes=[config['hidden_size'], config['hidden_size']*2, config['hidden_size']],
output_sequence_length=config['output_dim']
).to(DEVICE)
elif model_type == 'tcn':
model = TCNForecaster(
num_features=config['input_dim'],
output_sequence_length=config['output_dim'],
num_channels=[config['hidden_size']] * config['num_layers'],
kernel_size=3,
dropout=config['dropout']
).to(DEVICE)
else:
print(f"不支持的模型类型: {model_type}")
return None
print(f"模型实例创建成功: {type(model)}")
except Exception as e:
print(f"创建模型实例失败: {str(e)}")
return None
# 加载模型参数
try:
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
print("模型参数加载成功")
except Exception as e:
print(f"加载模型参数失败: {str(e)}")
return None
# 准备输入数据
try:
features = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
sequence_length = config['sequence_length']
# 获取最近的sequence_length天数据作为输入
recent_data = product_df.iloc[-sequence_length:].copy()
# 如果指定了起始日期,则使用该日期之后的数据
if start_date:
if isinstance(start_date, str):
start_date = datetime.strptime(start_date, '%Y-%m-%d')
recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
if len(recent_data) < sequence_length:
print(f"警告: 从指定日期 {start_date} 开始的数据少于所需的 {sequence_length}")
# 补充数据
missing_days = sequence_length - len(recent_data)
additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)
print(f"输入数据准备完成,形状: {recent_data.shape}")
except Exception as e:
print(f"准备输入数据失败: {str(e)}")
return None
# 归一化输入数据
try:
X = recent_data[features].values
X_scaled = scaler_X.transform(X)
# 转换为模型输入格式
X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
print(f"输入张量准备完成,形状: {X_input.shape}")
except Exception as e:
print(f"归一化输入数据失败: {str(e)}")
return None
# 预测
try:
with torch.no_grad():
y_pred_scaled = model(X_input).cpu().numpy()
print(f"原始预测输出形状: {y_pred_scaled.shape}")
# 处理TCN、Transformer、mLSTM和KAN模型的输出确保形状正确
if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
y_pred_scaled = y_pred_scaled.squeeze(-1)
print(f"处理后的预测输出形状: {y_pred_scaled.shape}")
# 反归一化预测结果
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
print(f"反归一化后的预测结果: {y_pred}")
# 生成预测日期
last_date = recent_data['date'].iloc[-1]
pred_dates = [(last_date + timedelta(days=i+1)) for i in range(len(y_pred))]
print(f"预测日期: {pred_dates}")
except Exception as e:
print(f"执行预测失败: {str(e)}")
return None
# 创建预测结果DataFrame
try:
predictions_df = pd.DataFrame({
'date': pred_dates,
'sales': y_pred # 使用sales字段名而不是predicted_sales以便与历史数据兼容
})
print(f"预测结果DataFrame创建成功形状: {predictions_df.shape}")
except Exception as e:
print(f"创建预测结果DataFrame失败: {str(e)}")
return None
# 绘制预测结果
try:
plt.figure(figsize=(12, 6))
plt.plot(product_df['date'], product_df['sales'], 'b-', label='历史销量')
plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='预测销量')
plt.title(f'{product_name} - {model_type}模型销量预测')
plt.xlabel('日期')
plt.ylabel('销量')
plt.legend()
plt.grid(True)
plt.xticks(rotation=45)
plt.tight_layout()
# 保存图像
plt.savefig(f'{product_id}_{model_type}_prediction.png')
plt.close()
print(f"预测结果已保存到 {product_id}_{model_type}_prediction.png")
except Exception as e:
print(f"绘制预测结果图表失败: {str(e)}")
# 这个错误不影响主要功能,继续执行
# 分析预测结果
analysis = None
if analyze_result:
try:
analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
print("\n预测结果分析:")
if analysis and 'explanation' in analysis:
print(analysis['explanation'])
else:
print("分析结果不包含explanation字段")
except Exception as e:
print(f"分析预测结果失败: {str(e)}")
# 分析失败不影响主要功能,继续执行
return {
'product_id': product_id,
'product_name': product_name,
'model_type': model_type,
'predictions': predictions_df,
'analysis': analysis
}
except Exception as e:
print(f"预测过程中出现未捕获的异常: {str(e)}")
import traceback
traceback.print_exc()
return None