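**Version**: 4.0 (final)

**Core idea**: The logical path becomes part of the filename, so all files are stored in a fully flat layout.

---

## 1. File Saving Rules

### 1.1. Core Principle

All metadata is encoded in the filename. A logical hierarchical path (for example `product/P001_all/mlstm/v2`) is converted into a filename prefix joined with underscores (`product_P001_all_mlstm_v2`).

### 1.2. Storage Locations

- **Final artifacts**: All final models, metadata files, loss plots, etc. are stored directly in the `saved_models/` root directory.
- **Intermediate files**: All checkpoints written during training are stored in the `saved_models/checkpoints/` directory.

### 1.3. Filename Generation

1. **Build the logical path**: Derive the logical path from the training parameters (mode, scope, model type, version).
   - *Example*: `product/P001_all/mlstm/v2`
2. **Generate the filename prefix**: Replace every `/` in the logical path with `_`.
   - *Example*: `product_P001_all_mlstm_v2`
3. **Append the file suffix**: Add a suffix that describes the file type:
   - `_model.pth`
   - `_metadata.json`
   - `_loss_curve.png`
   - `_checkpoint_best.pth`
   - `_checkpoint_epoch_{N}.pth`

#### **Complete example:**

- **Final model**: `saved_models/product_P001_all_mlstm_v2_model.pth`
- **Metadata**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
- **Best checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
- **Epoch 50 checkpoint**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 2. File Reading Rules

1. **Determine the model metadata**: Decide the training mode, scope, model type, and version of the model to load.
2. **Build the filename prefix**: Join the metadata into a filename prefix using the same logic as when saving (for example `product_P001_all_mlstm_v2`).
3. **Locate the file**:
   - To load the final model, look for `saved_models/{prefix}_model.pth`.
   - To load the best checkpoint, look for `saved_models/checkpoints/{prefix}_checkpoint_best.pth`.

---

## 3. Database Storage Rules

The database serves as an index. It should store enough key metadata to reconstruct the filename prefix.

#### **Suggested `models` table schema:**

| Column | Type | Description | Example |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | Primary key | 1 |
| `filename_prefix` | TEXT | **Full filename prefix; can serve as the unique identifier** | `product_P001_all_mlstm_v2` |
| `model_identifier` | TEXT | Identifier used for version control (without the version) | `product_P001_all_mlstm` |
| `version` | INTEGER | Version number | `2` |
| `status` | TEXT | Model status | `completed`, `training`, `failed` |
| `created_at` | TEXT | Creation time | `2025-07-21 02:29:00` |
| `metrics_summary` | TEXT | JSON string of key performance metrics | `{"rmse": 10.5, "r2": 0.89}` |

#### **Save logic:**

- After training completes, insert one record into the table. The `filename_prefix` field is the key for finding every file that belongs to that training run.

---

## 4. Version Tracking Rules

Version management relies on the `versions.json` file in the `saved_models/` root directory. The goal is atomic, thread-safe version increments.

- **Filename**: `versions.json`
- **Location**: `saved_models/versions.json`
- **Structure**: A JSON object. Each `key` is an identifier without the version number, and each `value` is the latest version number (an integer) for that identifier.
  - **Key**: `{prefix_core}_{model_type}` (for example `product_P001_all_mlstm`)
  - **Value**: `Integer`

#### **`versions.json` example:**

```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **Version management flow:**

1. **Get a new version**: Before training starts, build the `key`, read `versions.json`, and look up the `value` for that `key`. The new version number is `value + 1`, or `1` if the key does not exist.
2. **Update the version**: After training succeeds, write the new version number back to `versions.json`. This step **must use a file lock** to prevent concurrent conflicts.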
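The naming rule in Section 1.3 and the lookup in Section 2 both rely on the same string transformation. The sketch below assumes the logical path has already been built as `mode/scope/model_type/vN`. The helpers `build_prefix`, `artifact_path` and `epoch_checkpoint_suffix` are hypothetical and are not the `ModelPathManager` API used by the prediction module.

```python
import os

SAVE_ROOT = "saved_models"                             # final artifacts (Section 1.2)
CHECKPOINT_DIR = os.path.join(SAVE_ROOT, "checkpoints")  # training checkpoints


def build_prefix(logical_path: str) -> str:
    """Turn 'product/P001_all/mlstm/v2' into 'product_P001_all_mlstm_v2'."""
    return logical_path.strip("/").replace("/", "_")


def epoch_checkpoint_suffix(epoch: int) -> str:
    """Suffix for the checkpoint written at a given epoch."""
    return f"_checkpoint_epoch_{epoch}.pth"


def artifact_path(prefix: str, suffix: str) -> str:
    """Return the flat path for a prefix plus a file-type suffix.

    Checkpoint suffixes go to saved_models/checkpoints/, everything else
    (model, metadata, loss curve) goes to saved_models/.
    """
    directory = CHECKPOINT_DIR if suffix.startswith("_checkpoint_") else SAVE_ROOT
    return os.path.join(directory, prefix + suffix)


if __name__ == "__main__":
    prefix = build_prefix("product/P001_all/mlstm/v2")
    # on POSIX: saved_models/product_P001_all_mlstm_v2_model.pth
    print(artifact_path(prefix, "_model.pth"))
    # on POSIX: saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth
    print(artifact_path(prefix, "_checkpoint_best.pth"))
```

Scope segments such as `P001_all` already contain underscores. Because of this, a prefix cannot be split back into its fields unambiguously, which is one reason to keep the individual fields in the database (Section 3).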
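The `models` table in Section 3 maps directly onto SQLite. This sketch uses Python's built-in `sqlite3` module. The following are assumptions rather than part of the original design:

- the `UNIQUE` constraint on `filename_prefix`;
- the `{model_identifier}_v{version}` prefix format, which is taken from the table's example rows;
- the helper names `open_db`, `record_model` and `find_by_prefix`.

```python
import json
import sqlite3

SCHEMA = """
CREATE TABLE IF NOT EXISTS models (
    id               INTEGER PRIMARY KEY,
    filename_prefix  TEXT NOT NULL UNIQUE,  -- key for locating all files of a run
    model_identifier TEXT NOT NULL,         -- identifier without the version part
    version          INTEGER NOT NULL,
    status           TEXT NOT NULL,         -- 'completed', 'training' or 'failed'
    created_at       TEXT NOT NULL,         -- e.g. '2025-07-21 02:29:00'
    metrics_summary  TEXT                   -- JSON string of key metrics
)
"""


def open_db(path: str = ":memory:") -> sqlite3.Connection:
    """Open (or create) the index database and make sure the table exists."""
    conn = sqlite3.connect(path)
    conn.execute(SCHEMA)
    return conn


def record_model(conn, model_identifier, version, status, created_at, metrics):
    """Insert one row after training and return its filename_prefix."""
    prefix = f"{model_identifier}_v{version}"
    conn.execute(
        "INSERT INTO models (filename_prefix, model_identifier, version, status,"
        " created_at, metrics_summary) VALUES (?, ?, ?, ?, ?, ?)",
        (prefix, model_identifier, version, status, created_at, json.dumps(metrics)),
    )
    conn.commit()
    return prefix


def find_by_prefix(conn, prefix):
    """Look a run up by its filename_prefix; returns None if it is unknown."""
    row = conn.execute(
        "SELECT model_identifier, version, status, metrics_summary"
        " FROM models WHERE filename_prefix = ?",
        (prefix,),
    ).fetchone()
    if row is None:
        return None
    identifier, version, status, metrics = row
    return {
        "model_identifier": identifier,
        "version": version,
        "status": status,
        "metrics": json.loads(metrics) if metrics else None,
    }
```

Storing `model_identifier` and `version` as separate columns means the prefix can always be rebuilt from the row. The prefix cannot be parsed back into its fields.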
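The flow in Section 4 has a race condition:

- It reads the version without a lock and writes it back only after training.
- Two runs started at the same time can therefore both compute the same `value + 1`.
- Locking only the final write does not prevent this.

The sketch below instead reserves the number in a single locked read-modify-write. It assumes a POSIX system (`fcntl.flock`). `reserve_next_version` is a hypothetical helper. The trade-off is that a run that later fails leaves a gap in the numbering.

```python
import fcntl  # POSIX only; Windows would need msvcrt.locking or a lock-file library
import json
import os
import tempfile


def reserve_next_version(versions_path: str, key: str) -> int:
    """Read, increment and persist the version for `key` under one exclusive lock.

    Because the whole read-modify-write happens while holding the flock,
    two processes (or threads with their own file handles) cannot obtain
    the same version number.
    """
    directory = os.path.dirname(versions_path)
    if directory:
        os.makedirs(directory, exist_ok=True)
    # O_CREAT without O_TRUNC: create the file if missing, keep its contents otherwise.
    fd = os.open(versions_path, os.O_RDWR | os.O_CREAT, 0o644)
    with os.fdopen(fd, "r+", encoding="utf-8") as f:
        fcntl.flock(f.fileno(), fcntl.LOCK_EX)  # blocks until the lock is ours
        try:
            raw = f.read()
            versions = json.loads(raw) if raw.strip() else {}
            new_version = versions.get(key, 0) + 1  # 1 if the key is new
            versions[key] = new_version
            # Overwrite the file in place with the updated mapping.
            f.seek(0)
            f.truncate()
            json.dump(versions, f, ensure_ascii=False, indent=2)
            f.flush()
            os.fsync(f.fileno())
        finally:
            fcntl.flock(f.fileno(), fcntl.LOCK_UN)
    return new_version


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "versions.json")
    print(reserve_next_version(path, "product_P001_all_mlstm"))  # 1
    print(reserve_next_version(path, "product_P001_all_mlstm"))  # 2
```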
```python
"""
Pharmacy sales forecasting system - model prediction functions
"""

import os
import traceback
import torch
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
import sklearn.preprocessing._data  # needed so MinMaxScaler objects stored in checkpoints can be deserialized

from models.transformer_model import TimeSeriesTransformer
from models.slstm_model import sLSTM as ScalarLSTM
from models.mlstm_model import MLSTMTransformer as MatrixLSTM
from models.kan_model import KANForecaster
from models.tcn_model import TCNForecaster
from models.optimized_kan_forecaster import OptimizedKANForecaster

from analysis.trend_analysis import analyze_prediction_result
from utils.visualization import plot_prediction_results
from utils.multi_store_data_utils import get_store_product_sales_data, aggregate_multi_store_data
from core.config import DEVICE
from utils.file_save import ModelPathManager


def load_model_and_predict(product_id, model_type, model_path=None, store_id=None, future_days=7,
                           start_date=None, analyze_result=False, version=None,
                           training_mode='product', **kwargs):
    """
    Load a trained model and run a prediction.

    Args:
        product_id: Product ID.
        model_type: Model type ('transformer', 'slstm', 'mlstm', 'kan', 'tcn', 'optimized_kan').
        model_path: Full path to the model file. If None, the path is rebuilt with ModelPathManager.
        store_id: Store ID. None means the global model (data aggregated across all stores).
        future_days: Number of future days to predict. Currently only used in the log message;
            the actual horizon equals the model's config['output_dim'].
        start_date: First date of the input window. If None, the last sequence_length known days are used.
        analyze_result: Whether to analyze the prediction result.
        version: Model version (required when model_path is None).
        training_mode: Training mode used to rebuild the model path (default 'product').
        **kwargs: Extra parameters forwarded to ModelPathManager when rebuilding the path.

    Returns:
        A dict with the predictions and, if analyze_result is True, the analysis; None on failure.
    """
    try:
        # If no model_path is given, rebuild it dynamically with ModelPathManager
        if not model_path:
            if version is None:
                raise ValueError("'version' must be provided when loading via a dynamically built path.")

            path_manager = ModelPathManager()
            # Pass every parameter needed to reconstruct the path
            path_params = {
                'product_id': product_id,
                'store_id': store_id,
                **kwargs
            }
            model_path = path_manager.get_model_path_for_prediction(
                training_mode=training_mode,
                model_type=model_type,
                version=version,
                **path_params
            )

        # Log the path only after it has been resolved
        print(f"Trying to load model file: {model_path}")
        if not model_path or not os.path.exists(model_path):
            print(f"Model file {model_path} does not exist or could not be generated.")
            return None

        # Load sales data (multi-store aware)
        try:
            if store_id:
                # Load data for the specific store
                product_df = get_store_product_sales_data(
                    store_id,
                    product_id,
                    None  # use the default data path
                )
                store_name = (product_df['store_name'].iloc[0]
                              if 'store_name' in product_df.columns and not product_df.empty
                              else f"Store {store_id}")
                prediction_scope = f"store '{store_name}' ({store_id})"
            else:
                # Aggregate data across all stores
                product_df = aggregate_multi_store_data(
                    product_id,
                    aggregation_method='sum',
                    file_path=None  # use the default data path
                )
                prediction_scope = "all stores (aggregated)"
        except Exception as e:
            print(f"Multi-store data loading failed, falling back to the legacy data format: {e}")
            # Backward compatibility: try loading the legacy data format
            try:
                from core.config import DEFAULT_DATA_PATH
                from utils.multi_store_data_utils import load_multi_store_data
                df = load_multi_store_data(DEFAULT_DATA_PATH)
                product_df = df[df['product_id'] == product_id].sort_values('date')
                if store_id:
                    print("Warning: the legacy data does not support store filtering; all data will be used")
                prediction_scope = "default data"
            except Exception as e2:
                print(f"Failed to load product data: {e2}")
                return None

        if product_df.empty:
            print(f"Product {product_id} has no sales data")
            return None

        product_name = product_df['product_name'].iloc[0]
        print(f"Using the {model_type} model to predict sales of '{product_name}' "
              f"(ID: {product_id}) for the next {future_days} days")
        print(f"Prediction scope: {prediction_scope}")

        # Register MinMaxScaler as a safe global so it can be unpickled from the checkpoint
        try:
            torch.serialization.add_safe_globals([sklearn.preprocessing._data.MinMaxScaler])
        except Exception as e:
            print(f"Failed to add safe globals; this may not affect model loading: {e}")

        # Load the model and its configuration
        try:
            # First try loading with weights_only=False
            try:
                print("Trying to load the model with weights_only=False")
                checkpoint = torch.load(model_path, map_location=DEVICE, weights_only=False)
            except Exception as e:
                print(f"Loading with weights_only=False failed: {e}")
                print("Trying to load the model with default arguments")
                checkpoint = torch.load(model_path, map_location=DEVICE)

            print(f"Model loaded; checkpoint type: {type(checkpoint)}")
            if isinstance(checkpoint, dict):
                print(f"Checkpoint keys: {list(checkpoint.keys())}")
            else:
                print(f"Checkpoint is not a dict but: {type(checkpoint)}")
                return None
        except Exception as e:
            print(f"Failed to load model: {e}")
            return None

        # Check for and read the configuration
        if 'config' not in checkpoint:
            print("The model file contains no configuration")
            return None

        config = checkpoint['config']
        print(f"Model configuration: {config}")

        # Check for and read the scalers
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("The model file contains no scalers")
            return None

        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Build the model instance
        try:
            if model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=config['input_dim'],
                    d_model=config['hidden_size'],
                    nhead=config['num_heads'],
                    num_encoder_layers=config['num_layers'],
                    dim_feedforward=config['hidden_size'] * 2,
                    dropout=config['dropout'],
                    output_sequence_length=config['output_dim'],
                    seq_length=config['sequence_length'],
                    batch_size=32
                ).to(DEVICE)
            elif model_type == 'slstm':
                model = ScalarLSTM(
                    input_dim=config['input_dim'],
                    hidden_dim=config['hidden_size'],
                    output_dim=config['output_dim'],
                    num_layers=config['num_layers'],
                    dropout=config['dropout']
                ).to(DEVICE)
            elif model_type == 'mlstm':
                # Read optional parameters, falling back to defaults when absent
                embed_dim = config.get('embed_dim', 32)
                dense_dim = config.get('dense_dim', 32)
                num_heads = config.get('num_heads', 4)
                num_blocks = config.get('num_blocks', 3)

                model = MatrixLSTM(
                    num_features=config['input_dim'],
                    hidden_size=config['hidden_size'],
                    mlstm_layers=config['num_layers'],
                    embed_dim=embed_dim,
                    dense_dim=dense_dim,
                    num_heads=num_heads,
                    dropout_rate=config['dropout'],
                    num_blocks=num_blocks,
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'optimized_kan':
                model = OptimizedKANForecaster(
                    input_features=config['input_dim'],
                    hidden_sizes=[config['hidden_size'], config['hidden_size'] * 2, config['hidden_size']],
                    output_sequence_length=config['output_dim']
                ).to(DEVICE)
            elif model_type == 'tcn':
                model = TCNForecaster(
                    num_features=config['input_dim'],
                    output_sequence_length=config['output_dim'],
                    num_channels=[config['hidden_size']] * config['num_layers'],
                    kernel_size=3,
                    dropout=config['dropout']
                ).to(DEVICE)
            else:
                print(f"Unsupported model type: {model_type}")
                return None

            print(f"Model instance created: {type(model)}")
        except Exception as e:
            print(f"Failed to create model instance: {e}")
            return None

        # Load the model weights
        try:
            model.load_state_dict(checkpoint['model_state_dict'])
            model.eval()
            print("Model weights loaded")
        except Exception as e:
            print(f"Failed to load model weights: {e}")
            return None

        # Prepare the input data
        try:
            features = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
            sequence_length = config['sequence_length']

            # Default input: the most recent sequence_length days
            recent_data = product_df.iloc[-sequence_length:].copy()

            # If a start date is given, take the window from that date onward instead
            if start_date:
                if isinstance(start_date, str):
                    start_date = datetime.strptime(start_date, '%Y-%m-%d')
                recent_data = product_df[product_df['date'] >= start_date].iloc[:sequence_length].copy()
                if len(recent_data) < sequence_length:
                    print(f"Warning: fewer than the required {sequence_length} days of data from {start_date}")
                    # Pad with the days immediately before start_date
                    missing_days = sequence_length - len(recent_data)
                    additional_data = product_df[product_df['date'] < start_date].iloc[-missing_days:].copy()
                    recent_data = pd.concat([additional_data, recent_data]).reset_index(drop=True)

            # Fail early with a clear message instead of a reshape error below
            if len(recent_data) < sequence_length:
                print(f"Not enough data: need {sequence_length} days, got {len(recent_data)}")
                return None

            print(f"Input data ready, shape: {recent_data.shape}")
        except Exception as e:
            print(f"Failed to prepare input data: {e}")
            return None

        # Normalize the input data
        try:
            X = recent_data[features].values
            X_scaled = scaler_X.transform(X)

            # Convert to the model input format: (batch=1, sequence_length, num_features)
            X_input = torch.tensor(X_scaled.reshape(1, sequence_length, -1), dtype=torch.float32).to(DEVICE)
            print(f"Input tensor ready, shape: {X_input.shape}")
        except Exception as e:
            print(f"Failed to normalize input data: {e}")
            return None

        # Predict
        try:
            with torch.no_grad():
                y_pred_scaled = model(X_input).cpu().numpy()
            print(f"Raw prediction output shape: {y_pred_scaled.shape}")

            # TCN, Transformer, mLSTM and KAN models may return (batch, steps, 1); drop the last axis
            if model_type in ['tcn', 'transformer', 'mlstm', 'kan', 'optimized_kan'] and len(y_pred_scaled.shape) == 3:
                y_pred_scaled = y_pred_scaled.squeeze(-1)
                print(f"Processed prediction output shape: {y_pred_scaled.shape}")

            # Inverse-transform; scaler_y expects one column, so reshape the horizon to (steps, 1)
            y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
            print(f"Denormalized predictions: {y_pred}")

            # Forecast dates: the days following the last input date
            last_date = recent_data['date'].iloc[-1]
            pred_dates = [(last_date + timedelta(days=i + 1)) for i in range(len(y_pred))]
            print(f"Prediction dates: {pred_dates}")
        except Exception as e:
            print(f"Prediction failed: {e}")
            return None

        # Build the prediction DataFrame
        try:
            predictions_df = pd.DataFrame({
                'date': pred_dates,
                'sales': y_pred  # named 'sales' rather than 'predicted_sales' for compatibility with historical data
            })
            print(f"Prediction DataFrame created, shape: {predictions_df.shape}")
        except Exception as e:
            print(f"Failed to create prediction DataFrame: {e}")
            return None

        # Plot the prediction
        try:
            plt.figure(figsize=(12, 6))
            plt.plot(product_df['date'], product_df['sales'], 'b-', label='Historical sales')
            plt.plot(predictions_df['date'], predictions_df['sales'], 'r--', label='Predicted sales')
            plt.title(f'{product_name} - {model_type} model sales forecast')
            plt.xlabel('Date')
            plt.ylabel('Sales')
            plt.legend()
            plt.grid(True)
            plt.xticks(rotation=45)
            plt.tight_layout()

            # Save the figure to the current working directory
            plot_path = f'{product_id}_{model_type}_prediction.png'
            plt.savefig(plot_path)
            plt.close()

            print(f"Prediction plot saved to {plot_path}")
        except Exception as e:
            print(f"Failed to plot prediction results: {e}")
            # Plotting is not essential; continue

        # Analyze the prediction
        analysis = None
        if analyze_result:
            try:
                analysis = analyze_prediction_result(product_id, model_type, y_pred, X)
                print("\nPrediction analysis:")
                if analysis and 'explanation' in analysis:
                    print(analysis['explanation'])
                else:
                    print("The analysis result has no 'explanation' field")
            except Exception as e:
                print(f"Failed to analyze prediction results: {e}")
                # Analysis is not essential; continue

        return {
            'product_id': product_id,
            'product_name': product_name,
            'model_type': model_type,
            'predictions': predictions_df,
            'analysis': analysis
        }

    except Exception as e:
        print(f"Uncaught exception during prediction: {e}")
        traceback.print_exc()
        return None
```
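`load_model_and_predict` expects the checkpoint to be a dict with `config`, `scaler_X`, `scaler_y` and `model_state_dict`. It also reads different `config` keys depending on `model_type`. The sketch below collects those requirements into a hypothetical `missing_checkpoint_keys` helper, based on the constructor calls in the function above. Such a helper could be run before building a model. It checks only that keys are present, not their value types.

```python
from typing import Dict, List, Mapping, Tuple

# Top-level checkpoint entries read by load_model_and_predict.
CHECKPOINT_KEYS: Tuple[str, ...] = ("config", "scaler_X", "scaler_y", "model_state_dict")

# Read for every model type ('sequence_length' is used to build the input window).
_COMMON: Tuple[str, ...] = ("input_dim", "output_dim", "hidden_size", "sequence_length")

# Config keys each constructor reads without a default value.
REQUIRED_CONFIG_KEYS: Dict[str, Tuple[str, ...]] = {
    "transformer": _COMMON + ("num_heads", "num_layers", "dropout"),
    "slstm": _COMMON + ("num_layers", "dropout"),
    "mlstm": _COMMON + ("num_layers", "dropout"),  # embed_dim, dense_dim, ... have defaults
    "kan": _COMMON,
    "optimized_kan": _COMMON,
    "tcn": _COMMON + ("num_layers", "dropout"),
}


def missing_checkpoint_keys(checkpoint: Mapping, model_type: str) -> List[str]:
    """List every required entry absent from `checkpoint` for `model_type`."""
    if model_type not in REQUIRED_CONFIG_KEYS:
        raise ValueError(f"unsupported model type: {model_type}")
    missing = [k for k in CHECKPOINT_KEYS if k not in checkpoint]
    config = checkpoint.get("config") or {}
    missing += [f"config.{k}" for k in REQUIRED_CONFIG_KEYS[model_type] if k not in config]
    return missing
```

A non-empty result tells you which entry is absent before the function fails with a less specific `KeyError` message.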
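The input-window step works as follows:

- It takes `sequence_length` rows starting at `start_date`.
- If that window runs past the end of the data, it pads with the rows immediately before `start_date`.
- Predictions then start on the day after the last input date.

The sketch below reproduces the same rule over a sorted list of dates, without pandas. The names `select_window` and `forecast_dates` are hypothetical. `select_window` returns row positions, the way `iloc` slices `product_df`.

```python
from datetime import date, timedelta
from typing import List, Optional, Sequence


def select_window(dates: Sequence[date], sequence_length: int,
                  start_date: Optional[date] = None) -> List[int]:
    """Return the row positions used as model input (dates must be sorted)."""
    n = len(dates)
    if start_date is None:
        # Default: the last sequence_length rows.
        return list(range(max(0, n - sequence_length), n))
    # First row on or after start_date (like product_df['date'] >= start_date).
    first = next((i for i, d in enumerate(dates) if d >= start_date), n)
    window = list(range(first, min(first + sequence_length, n)))
    missing = sequence_length - len(window)
    if missing > 0:
        # Pad with the rows immediately before start_date.
        window = list(range(max(0, first - missing), first)) + window
    return window


def forecast_dates(last_input_date: date, horizon: int) -> List[date]:
    """Dates of the forecast: the `horizon` days after the last input date."""
    return [last_input_date + timedelta(days=i + 1) for i in range(horizon)]
```

If there is too little data even after padding, the window is shorter than `sequence_length`. The caller must reject it, as the length check in `load_model_and_predict` does.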