# ShopTRAINING/server/models/model_manager.py
#
# NOTE: the repository viewer flagged this file as containing ambiguous
# Unicode characters (characters that may be confused with other characters).
# If that is intentional, the warning can be ignored.


import os
import torch
import glob
import pandas as pd
import matplotlib.pyplot as plt
from datetime import datetime
import json
import shutil
from .utils import get_device, to_device
from .mlstm_model import MLSTMTransformer
from .transformer_model import TimeSeriesTransformer
from .kan_model import KANForecaster
class ModelManager:
    """Manage persisted forecasting models: save, load, list, delete, export, import.

    Files are stored flat inside ``models_dir`` using two naming patterns:

    * ``<product_id>_<model_type>_model_v<version>.pt``  -- checkpoint written
      by ``torch.save``
    * ``<product_id>_<model_type>_meta_v<version>.json`` -- JSON metadata

    The default version is a ``YYYYmmdd_HHMMSS`` timestamp, so a version may
    itself contain underscores; filenames are therefore parsed around the
    ``_model_v`` marker rather than by splitting on ``_``.
    """

    def __init__(self, models_dir='models'):
        """Initialise the manager.

        Args:
            models_dir: Directory in which model files are stored. It is
                created if it does not exist.
        """
        self.models_dir = models_dir
        self._ensure_model_dir()
        # Maps the model-type string to the class that implements it.
        self.model_types = {
            'mlstm': MLSTMTransformer,
            'transformer': TimeSeriesTransformer,
            'kan': KANForecaster
        }

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _ensure_model_dir(self):
        """Create ``models_dir`` if it is missing; re-raise on failure."""
        if not os.path.exists(self.models_dir):
            try:
                os.makedirs(self.models_dir, exist_ok=True)
                print(f"创建模型目录: {os.path.abspath(self.models_dir)}")
            except Exception as e:
                print(f"创建模型目录失败: {str(e)}")
                raise

    def _model_path(self, product_id, model_type, version):
        """Return the checkpoint path (``version`` may be ``'*'`` for a glob)."""
        return os.path.join(
            self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")

    def _meta_path(self, product_id, model_type, version):
        """Return the metadata path (``version`` may be ``'*'`` for a glob)."""
        return os.path.join(
            self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")

    def _latest_model_path(self, product_id, model_type):
        """Return the most recently modified checkpoint path, or ``None``."""
        candidates = glob.glob(self._model_path(product_id, model_type, '*'))
        if not candidates:
            return None
        return max(candidates, key=os.path.getmtime)

    @staticmethod
    def _parse_model_filename(filename):
        """Split a checkpoint filename into ``(product_id, model_type, version)``.

        Returns ``None`` when the name does not follow
        ``<product_id>_<model_type>_model_v<version>.pt``.
        """
        stem, _ext = os.path.splitext(os.path.basename(filename))
        prefix, marker, version = stem.rpartition('_model_v')
        if not marker:
            return None
        # The model type is the last "_"-separated token of the prefix, so a
        # product id that itself contains underscores is kept intact.
        product_id, sep, model_type = prefix.rpartition('_')
        if not sep or not product_id or not model_type:
            return None
        return product_id, model_type, version

    @staticmethod
    def _load_product_data(product_id):
        """Load one product's rows from ``pharmacy_sales.xlsx``, sorted by date.

        Returns ``None`` (after printing the error) when loading fails.
        """
        try:
            df = pd.read_excel('pharmacy_sales.xlsx')
            return df[df['product_id'] == product_id].sort_values('date')
        except Exception as e:
            print(f"加载产品数据时出错: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------
    def save_model(self, model, model_type, product_id, optimizer=None,
                   train_loss=None, test_loss=None, scaler_X=None,
                   scaler_y=None, features=None, look_back=None, T=None,
                   metrics=None, version=None):
        """Save a model checkpoint plus a JSON metadata side-car.

        Args:
            model: Trained model; its ``state_dict()`` is stored.
            model_type: Model type string ('mlstm', 'transformer', 'kan').
            product_id: Product identifier.
            optimizer: Optional optimizer whose ``state_dict()`` is stored.
            train_loss: Optional training-loss history.
            test_loss: Optional test-loss history.
            scaler_X: Optional feature scaler (stored in the checkpoint only).
            scaler_y: Optional target scaler (stored in the checkpoint only).
            features: List of feature names used by the model.
            look_back: Look-back window length.
            T: Forecast horizon.
            metrics: Optional evaluation metrics.
            version: Optional version string; defaults to a
                ``YYYYmmdd_HHMMSS`` timestamp.

        Returns:
            Path of the written checkpoint file.

        Raises:
            Exception: Any error raised while writing is printed and re-raised.
        """
        self._ensure_model_dir()

        if version is None:
            version = datetime.now().strftime("%Y%m%d_%H%M%S")

        model_path = self._model_path(product_id, model_type, version)

        save_dict = {
            'model_state_dict': model.state_dict(),
            'model_type': model_type,
            'product_id': product_id,
            'version': version,
            'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
            'features': features,
            'look_back': look_back,
            'T': T
        }

        # Optional entries are only stored when they were supplied.
        if optimizer is not None:
            save_dict['optimizer_state_dict'] = optimizer.state_dict()
        optional_entries = {
            'train_loss': train_loss,
            'test_loss': test_loss,
            'scaler_X': scaler_X,
            'scaler_y': scaler_y,
            'metrics': metrics,
        }
        for key, value in optional_entries.items():
            if value is not None:
                save_dict[key] = value

        try:
            torch.save(save_dict, model_path)
            print(f"模型已成功保存到 {os.path.abspath(model_path)}")

            # Write a JSON side-car so models can be queried without loading
            # the checkpoint. Tensors and scalers are left out; any other
            # value that is not a plain JSON type is stringified.
            excluded = {'model_state_dict', 'optimizer_state_dict',
                        'scaler_X', 'scaler_y'}
            json_native = (int, float, bool, list, dict, type(None))
            meta_dict = {
                key: value if isinstance(value, json_native) else str(value)
                for key, value in save_dict.items()
                if key not in excluded
            }
            # Metrics are stored as given (not stringified), as before.
            if metrics is not None:
                meta_dict['metrics'] = metrics

            meta_path = self._meta_path(product_id, model_type, version)
            with open(meta_path, 'w', encoding='utf-8') as f:
                json.dump(meta_dict, f, indent=4)
            return model_path
        except Exception as e:
            print(f"保存模型时出错: {str(e)}")
            raise

    def load_model(self, product_id, model_type='mlstm', version=None, device=None):
        """Load a saved model.

        Args:
            product_id: Product identifier.
            model_type: Model type string ('mlstm', 'transformer', 'kan').
            version: Version to load; the most recently modified checkpoint
                is used when omitted.
            device: Target device; defaults to ``get_device()``.

        Returns:
            ``(model, checkpoint)``, or ``(None, None)`` when no matching file
            exists.

        Raises:
            ValueError: For an unsupported ``model_type``.
            Exception: Any other loading error is printed and re-raised.
        """
        if device is None:
            device = get_device()

        if version is None:
            model_path = self._latest_model_path(product_id, model_type)
            if model_path is None:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None, None
        else:
            model_path = self._model_path(product_id, model_type, version)
            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return None, None

        try:
            # NOTE(review): torch.load unpickles the file, and checkpoints
            # written by save_model may contain scaler objects. Only load
            # files from trusted sources. Newer PyTorch releases may default
            # to weights_only=True and reject such checkpoints -- confirm
            # against the deployed PyTorch version.
            checkpoint = torch.load(model_path, map_location=device)

            # NOTE(review): the architecture hyper-parameters below are
            # hard-coded and presumably must match those used for training;
            # they are not stored in the checkpoint -- verify against the
            # training code.
            if model_type == 'mlstm':
                model = MLSTMTransformer(
                    num_features=len(checkpoint['features']),
                    hidden_size=128,
                    mlstm_layers=1,
                    embed_dim=32,
                    dense_dim=32,
                    num_heads=4,
                    dropout_rate=0.1,
                    num_blocks=3,
                    output_sequence_length=checkpoint['T']
                )
            elif model_type == 'transformer':
                model = TimeSeriesTransformer(
                    num_features=len(checkpoint['features']),
                    d_model=32,
                    nhead=4,
                    num_encoder_layers=3,
                    dim_feedforward=32,
                    dropout=0.1,
                    output_sequence_length=checkpoint['T']
                )
            elif model_type == 'kan':
                model = KANForecaster(
                    input_features=len(checkpoint['features']),
                    hidden_sizes=[64, 128, 64],
                    output_size=1,
                    grid_size=5,
                    spline_order=3,
                    dropout_rate=0.1,
                    output_sequence_length=checkpoint['T']
                )
            else:
                raise ValueError(f"不支持的模型类型: {model_type}")

            model.load_state_dict(checkpoint['model_state_dict'])
            model = model.to(device)
            model.eval()
            print(f"模型已从 {os.path.abspath(model_path)} 成功加载")
            return model, checkpoint
        except Exception as e:
            print(f"加载模型时出错: {str(e)}")
            raise

    # ------------------------------------------------------------------
    # Catalogue operations
    # ------------------------------------------------------------------
    def list_models(self, product_id=None, model_type=None):
        """List saved models, optionally filtered.

        Args:
            product_id: Optional product-id filter.
            model_type: Optional model-type filter.

        Returns:
            List of info dicts (newest ``created_at`` first). Each dict holds
            the values parsed from the filename plus file path, creation time
            and size; the metadata JSON, when present, is merged on top.
            An empty list is returned when nothing matches.
        """
        self._ensure_model_dir()

        if product_id and model_type:
            pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
        elif product_id:
            pattern = os.path.join(self.models_dir, f"{product_id}_*_model_v*.pt")
        elif model_type:
            pattern = os.path.join(self.models_dir, f"*_{model_type}_model_v*.pt")
        else:
            pattern = os.path.join(self.models_dir, "*_model_v*.pt")

        model_files = glob.glob(pattern)
        if not model_files:
            print("未找到匹配的模型文件")
            return []

        models_list = []
        for model_path in model_files:
            try:
                parsed = self._parse_model_filename(model_path)
                if parsed is None:
                    # Not a file this manager wrote; ignore it.
                    continue
                # Separate local names: do not clobber the filter arguments.
                file_product_id, file_model_type, file_version = parsed

                model_info = {
                    'product_id': file_product_id,
                    'model_type': file_model_type,
                    'version': file_version,
                    'file_path': model_path,
                    'created_at': datetime.fromtimestamp(
                        os.path.getctime(model_path)).strftime("%Y-%m-%d %H:%M:%S"),
                    'file_size': f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
                }

                meta_path = self._meta_path(file_product_id, file_model_type, file_version)
                if os.path.exists(meta_path):
                    with open(meta_path, 'r', encoding='utf-8') as f:
                        model_info.update(json.load(f))

                models_list.append(model_info)
            except Exception as e:
                # Best effort: one unreadable entry must not hide the others.
                print(f"解析模型文件 {model_path} 时出错: {str(e)}")

        models_list.sort(key=lambda info: info['created_at'], reverse=True)
        return models_list

    def delete_model(self, product_id, model_type, version=None):
        """Delete one version, or every version, of a model.

        Args:
            product_id: Product identifier.
            model_type: Model type string.
            version: Version to delete. When falsy (``None``, empty string,
                ...) ALL versions for this product/type are deleted.

        Returns:
            ``True`` on success, ``False`` when nothing matched or a removal
            failed.
        """
        self._ensure_model_dir()

        if version:
            model_path = self._model_path(product_id, model_type, version)
            meta_path = self._meta_path(product_id, model_type, version)
            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return False
            try:
                os.remove(model_path)
                if os.path.exists(meta_path):
                    os.remove(meta_path)
                print(f"已删除产品 {product_id} 的 {model_type} 模型版本 {version}")
                return True
            except Exception as e:
                print(f"删除模型时出错: {str(e)}")
                return False

        model_files = glob.glob(self._model_path(product_id, model_type, '*'))
        meta_files = glob.glob(self._meta_path(product_id, model_type, '*'))
        if not model_files:
            print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
            return False
        try:
            for file_path in model_files + meta_files:
                os.remove(file_path)
            print(f"已删除产品 {product_id} 的所有 {model_type} 模型")
            return True
        except Exception as e:
            print(f"删除模型时出错: {str(e)}")
            return False

    def get_model_details(self, product_id, model_type, version=None):
        """Return the metadata of a saved model.

        Args:
            product_id: Product identifier.
            model_type: Model type string.
            version: Version to inspect; the most recently modified checkpoint
                is used when omitted.

        Returns:
            The metadata dict extended with ``file_path`` and ``file_size``,
            or ``None`` when the model or its metadata cannot be read.
        """
        if version is None:
            latest_path = self._latest_model_path(product_id, model_type)
            if latest_path is None:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None
            parsed = self._parse_model_filename(latest_path)
            if parsed is None:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None
            version = parsed[2]

        meta_path = self._meta_path(product_id, model_type, version)
        if not os.path.exists(meta_path):
            print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version} 的元数据")
            return None

        try:
            with open(meta_path, 'r', encoding='utf-8') as f:
                details = json.load(f)
            model_path = self._model_path(product_id, model_type, version)
            details['file_path'] = model_path
            details['file_size'] = f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
            return details
        except Exception as e:
            print(f"获取模型详情时出错: {str(e)}")
            return None

    # ------------------------------------------------------------------
    # Inference
    # ------------------------------------------------------------------
    def predict_with_model(self, product_id, model_type='mlstm', version=None, future_days=7,
                           product_df=None, features=None, visualize=True, save_results=True):
        """Forecast future sales with a saved model.

        Args:
            product_id: Product identifier.
            model_type: Model type string ('mlstm', 'transformer', 'kan').
            version: Model version; the latest is used when omitted.
            future_days: Currently unused -- the horizon is the ``T`` value
                stored in the checkpoint.
            product_df: Product DataFrame; loaded from ``pharmacy_sales.xlsx``
                when omitted.
            features: Currently unused -- the feature list stored in the
                checkpoint is used instead.
            visualize: When true, save a forecast chart as
                ``<product_id>_<model_type>_forecast.png``.
            save_results: When true, save the forecast as
                ``<product_id>_<model_type>_forecast.csv``.

        Returns:
            DataFrame with ``date``, ``product_id``, ``product_name`` and
            ``predicted_sales`` columns, or ``None`` when the model or the
            data cannot be loaded.
        """
        device = get_device()
        print(f"使用设备: {device} 进行预测")

        model, checkpoint = self.load_model(product_id, model_type, version, device)
        if model is None or checkpoint is None:
            return None

        if product_df is None:
            product_df = self._load_product_data(product_id)
            if product_df is None:
                return None

        # Without rows there is neither a product name nor an input window.
        if len(product_df) == 0:
            print(f"错误: 未找到产品 {product_id} 的销售数据")
            return None

        # save_model stores the scalers only when they were supplied.
        if 'scaler_X' not in checkpoint or 'scaler_y' not in checkpoint:
            print("错误: 模型检查点缺少缩放器 (scaler_X/scaler_y), 无法进行预测")
            return None

        product_name = product_df['product_name'].iloc[0]

        feature_cols = checkpoint['features']
        look_back = checkpoint['look_back']
        T = checkpoint['T']
        scaler_X = checkpoint['scaler_X']
        scaler_y = checkpoint['scaler_y']

        # Use the most recent ``look_back`` rows as the model input window.
        last_data = product_df[feature_cols].values[-look_back:]
        last_data_scaled = scaler_X.transform(last_data)

        # unsqueeze(0) adds the batch dimension.
        X_input = torch.Tensor(last_data_scaled).unsqueeze(0).to(device)

        with torch.no_grad():
            y_pred_scaled = model(X_input).squeeze(0).cpu().numpy()

        # Map the prediction back to the original sales scale.
        y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()

        # pd.to_datetime keeps this working when the date column holds strings.
        last_date = pd.to_datetime(product_df['date'].iloc[-1])
        future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=T, freq='D')

        predictions_df = pd.DataFrame({
            'date': future_dates,
            'product_id': product_id,
            'product_name': product_name,
            'predicted_sales': y_pred
        })

        print(f"\n{product_name} 未来 {T} 天销售预测 (使用{model_type.upper()}模型):")
        print(predictions_df[['date', 'predicted_sales']])

        if visualize:
            fig = plt.figure(figsize=(12, 6))
            try:
                history_days = 30  # number of trailing history days to plot
                history_dates = product_df['date'].iloc[-history_days:].values
                history_sales = product_df['sales'].iloc[-history_days:].values

                plt.plot(history_dates, history_sales, 'b-', label='历史销量')
                plt.plot(future_dates, y_pred, 'r--', label=f'{model_type.upper()}预测销量')
                plt.title(f'{product_name} - {model_type.upper()}销量预测 (未来{T}天)')
                plt.xlabel('日期')
                plt.ylabel('销量')
                plt.legend()
                plt.grid(True)
                plt.xticks(rotation=45)
                plt.tight_layout()

                forecast_chart = f'{product_id}_{model_type}_forecast.png'
                plt.savefig(forecast_chart)
                print(f"预测图表已保存为: {forecast_chart}")
            finally:
                # Release the figure so repeated calls do not accumulate them.
                plt.close(fig)

        if save_results:
            forecast_csv = f'{product_id}_{model_type}_forecast.csv'
            predictions_df.to_csv(forecast_csv, index=False)
            print(f"预测结果已保存到: {forecast_csv}")

        return predictions_df

    def compare_models(self, product_id, model_types=None, versions=None, product_df=None, visualize=True):
        """Run several saved models on the same product and compare forecasts.

        Args:
            product_id: Product identifier.
            model_types: Model types to compare; defaults to
                ``['mlstm', 'transformer', 'kan']``.
            versions: Versions matching ``model_types`` position by position;
                defaults to the latest version of each.
            product_df: Product DataFrame; loaded from ``pharmacy_sales.xlsx``
                when omitted.
            visualize: When true, save a comparison chart as
                ``<product_id>_model_comparison.png``.

        Returns:
            DataFrame with one ``<model_type>_prediction`` column per model
            that produced a forecast (also written to
            ``<product_id>_model_comparison.csv``), or ``None`` on failure.
        """
        if model_types is None:
            model_types = ['mlstm', 'transformer', 'kan']
        if versions is None:
            versions = [None] * len(model_types)
        if len(versions) != len(model_types):
            print("错误: 模型类型和版本列表长度不匹配")
            return None

        if product_df is None:
            product_df = self._load_product_data(product_id)
            if product_df is None:
                return None

        if len(product_df) == 0:
            print(f"错误: 未找到产品 {product_id} 的销售数据")
            return None

        product_name = product_df['product_name'].iloc[0]

        # Forecast per model type; a failing model is reported and skipped.
        predictions = {}
        for model_type, version in zip(model_types, versions):
            try:
                pred_df = self.predict_with_model(
                    product_id,
                    model_type=model_type,
                    version=version,
                    product_df=product_df,
                    visualize=False,
                    save_results=False
                )
                if pred_df is not None:
                    predictions[model_type] = pred_df
            except Exception as e:
                print(f"{model_type} 模型预测出错: {str(e)}")

        if not predictions:
            print("没有成功的预测结果")
            return None

        # The first successful forecast supplies the date/product columns.
        first_pred = next(iter(predictions.values()))
        result_df = first_pred[['date', 'product_id', 'product_name']].copy()
        for model_type, pred_df in predictions.items():
            result_df[f'{model_type}_prediction'] = pred_df['predicted_sales'].values

        if visualize:
            fig = plt.figure(figsize=(12, 6))
            try:
                history_days = 30  # number of trailing history days to plot
                history_dates = product_df['date'].iloc[-history_days:].values
                history_sales = product_df['sales'].iloc[-history_days:].values
                plt.plot(history_dates, history_sales, 'k-', label='历史销量')

                colors = ['r', 'g', 'b', 'c', 'm', 'y']
                future_dates = result_df['date'].values
                for i, (model_type, pred_df) in enumerate(predictions.items()):
                    color = colors[i % len(colors)]
                    plt.plot(future_dates, pred_df['predicted_sales'].values,
                             f'{color}--', label=f'{model_type.upper()}预测')

                plt.title(f'{product_name} - 不同模型预测结果比较')
                plt.xlabel('日期')
                plt.ylabel('销量')
                plt.legend()
                plt.grid(True)
                plt.xticks(rotation=45)
                plt.tight_layout()

                compare_chart = f'{product_id}_model_comparison.png'
                plt.savefig(compare_chart)
                print(f"比较图表已保存为: {compare_chart}")
            finally:
                # Release the figure so repeated calls do not accumulate them.
                plt.close(fig)

        compare_csv = f'{product_id}_model_comparison.csv'
        result_df.to_csv(compare_csv, index=False)
        print(f"比较结果已保存到: {compare_csv}")

        return result_df

    # ------------------------------------------------------------------
    # Export / import
    # ------------------------------------------------------------------
    def export_model(self, product_id, model_type, version=None, export_dir='exported_models'):
        """Copy a model (and its metadata, if any) into ``export_dir``.

        Args:
            product_id: Product identifier.
            model_type: Model type string.
            version: Version to export; the most recently modified checkpoint
                is used when omitted.
            export_dir: Destination directory (created when missing).

        Returns:
            Path of the exported checkpoint, or ``None`` on failure.
        """
        if not os.path.exists(export_dir):
            os.makedirs(export_dir, exist_ok=True)

        if version is None:
            model_path = self._latest_model_path(product_id, model_type)
            if model_path is None:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None
            parsed = self._parse_model_filename(model_path)
            if parsed is None:
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
                return None
            version = parsed[2]
        else:
            model_path = self._model_path(product_id, model_type, version)
            if not os.path.exists(model_path):
                print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
                return None

        meta_path = self._meta_path(product_id, model_type, version)
        export_model_path = os.path.join(export_dir, f"{product_id}_{model_type}_model_v{version}.pt")
        export_meta_path = os.path.join(export_dir, f"{product_id}_{model_type}_meta_v{version}.json")

        try:
            shutil.copy2(model_path, export_model_path)
            if os.path.exists(meta_path):
                shutil.copy2(meta_path, export_meta_path)
            print(f"模型已导出到 {os.path.abspath(export_model_path)}")
            return export_model_path
        except Exception as e:
            print(f"导出模型时出错: {str(e)}")
            return None

    def import_model(self, import_file, overwrite=False):
        """Copy a model file (and its sibling metadata file) into ``models_dir``.

        Args:
            import_file: Path of the checkpoint file to import.
            overwrite: Whether an existing file of the same name may be
                replaced.

        Returns:
            Path of the imported checkpoint, or ``None`` on failure.
        """
        self._ensure_model_dir()

        if not os.path.exists(import_file):
            print(f"错误: 导入文件 {import_file} 不存在")
            return None

        filename = os.path.basename(import_file)
        target_path = os.path.join(self.models_dir, filename)

        if os.path.exists(target_path) and not overwrite:
            print(f"错误: 目标文件 {target_path} 已存在如需覆盖请设置overwrite=True")
            return None

        try:
            shutil.copy2(import_file, target_path)

            # Derive the metadata name from the basename only, so directory
            # names containing "_model_v" or ".pt" are never rewritten.
            stem, _ext = os.path.splitext(filename)
            prefix, marker, version = stem.rpartition('_model_v')
            meta_stem = f"{prefix}_meta_v{version}" if marker else stem
            meta_filename = meta_stem + '.json'
            meta_import_file = os.path.join(os.path.dirname(import_file), meta_filename)

            if meta_import_file != import_file and os.path.exists(meta_import_file):
                shutil.copy2(meta_import_file, os.path.join(self.models_dir, meta_filename))

            print(f"模型已导入到 {os.path.abspath(target_path)}")
            return target_path
        except Exception as e:
            print(f"导入模型时出错: {str(e)}")
            return None