708 lines
27 KiB
Python
708 lines
27 KiB
Python
import os
|
||
import torch
|
||
import glob
|
||
import pandas as pd
|
||
import matplotlib.pyplot as plt
|
||
from datetime import datetime
|
||
import json
|
||
import shutil
|
||
|
||
from .utils import get_device, to_device
|
||
from .mlstm_model import MLSTMTransformer
|
||
from .transformer_model import TimeSeriesTransformer
|
||
from .kan_model import KANForecaster
|
||
|
||
class ModelManager:
|
||
"""
|
||
模型管理类:负责模型的保存、加载、列出和删除等操作
|
||
"""
|
||
|
||
def __init__(self, models_dir='models'):
|
||
"""
|
||
初始化模型管理器
|
||
|
||
参数:
|
||
models_dir: 模型存储目录
|
||
"""
|
||
self.models_dir = models_dir
|
||
self._ensure_model_dir()
|
||
|
||
# 模型类型映射
|
||
self.model_types = {
|
||
'mlstm': MLSTMTransformer,
|
||
'transformer': TimeSeriesTransformer,
|
||
'kan': KANForecaster
|
||
}
|
||
|
||
def _ensure_model_dir(self):
|
||
"""确保模型目录存在"""
|
||
if not os.path.exists(self.models_dir):
|
||
try:
|
||
os.makedirs(self.models_dir, exist_ok=True)
|
||
print(f"创建模型目录: {os.path.abspath(self.models_dir)}")
|
||
except Exception as e:
|
||
print(f"创建模型目录失败: {str(e)}")
|
||
raise
|
||
|
||
def save_model(self, model, model_type, product_id, optimizer=None,
|
||
train_loss=None, test_loss=None, scaler_X=None,
|
||
scaler_y=None, features=None, look_back=None, T=None,
|
||
metrics=None, version=None):
|
||
"""
|
||
保存模型及其相关信息
|
||
|
||
参数:
|
||
model: 训练好的模型
|
||
model_type: 模型类型 ('mlstm', 'transformer', 'kan')
|
||
product_id: 产品ID
|
||
optimizer: 优化器
|
||
train_loss: 训练损失历史
|
||
test_loss: 测试损失历史
|
||
scaler_X: 特征缩放器
|
||
scaler_y: 目标缩放器
|
||
features: 使用的特征列表
|
||
look_back: 回看天数
|
||
T: 预测天数
|
||
metrics: 模型评估指标
|
||
version: 模型版本(可选),如果不提供则使用时间戳
|
||
"""
|
||
self._ensure_model_dir()
|
||
|
||
# 设置版本
|
||
if version is None:
|
||
version = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||
|
||
# 设置文件名
|
||
model_filename = f"{product_id}_{model_type}_model_v{version}.pt"
|
||
model_path = os.path.join(self.models_dir, model_filename)
|
||
|
||
# 准备要保存的数据
|
||
save_dict = {
|
||
'model_state_dict': model.state_dict(),
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'created_at': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
||
'features': features,
|
||
'look_back': look_back,
|
||
'T': T
|
||
}
|
||
|
||
# 添加可选数据
|
||
if optimizer is not None:
|
||
save_dict['optimizer_state_dict'] = optimizer.state_dict()
|
||
if train_loss is not None:
|
||
save_dict['train_loss'] = train_loss
|
||
if test_loss is not None:
|
||
save_dict['test_loss'] = test_loss
|
||
if scaler_X is not None:
|
||
save_dict['scaler_X'] = scaler_X
|
||
if scaler_y is not None:
|
||
save_dict['scaler_y'] = scaler_y
|
||
if metrics is not None:
|
||
save_dict['metrics'] = metrics
|
||
|
||
try:
|
||
# 保存模型
|
||
torch.save(save_dict, model_path)
|
||
print(f"模型已成功保存到 {os.path.abspath(model_path)}")
|
||
|
||
# 保存模型的元数据到JSON文件,便于查询
|
||
meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
meta_dict = {k: str(v) if not isinstance(v, (int, float, bool, list, dict, type(None))) else v
|
||
for k, v in save_dict.items() if k != 'model_state_dict' and
|
||
k != 'optimizer_state_dict' and k != 'scaler_X' and k != 'scaler_y'}
|
||
|
||
# 如果有评估指标,添加到元数据
|
||
if metrics is not None:
|
||
meta_dict['metrics'] = metrics
|
||
|
||
with open(meta_path, 'w') as f:
|
||
json.dump(meta_dict, f, indent=4)
|
||
|
||
return model_path
|
||
except Exception as e:
|
||
print(f"保存模型时出错: {str(e)}")
|
||
raise
|
||
|
||
def load_model(self, product_id, model_type='mlstm', version=None, device=None):
|
||
"""
|
||
加载指定的模型
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_type: 模型类型 ('mlstm', 'transformer', 'kan')
|
||
version: 模型版本,如果不指定则加载最新版本
|
||
device: 设备 (cuda/cpu)
|
||
|
||
返回:
|
||
model: 加载的模型
|
||
checkpoint: 包含模型信息的字典
|
||
"""
|
||
if device is None:
|
||
device = get_device()
|
||
|
||
# 查找匹配的模型文件
|
||
if version is None:
|
||
# 查找最新版本
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
|
||
model_files = glob.glob(pattern)
|
||
|
||
if not model_files:
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
|
||
return None, None
|
||
|
||
# 按照文件修改时间排序,获取最新的
|
||
model_path = max(model_files, key=os.path.getmtime)
|
||
else:
|
||
# 指定版本
|
||
model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
|
||
if not os.path.exists(model_path):
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
|
||
return None, None
|
||
|
||
try:
|
||
# 加载模型
|
||
checkpoint = torch.load(model_path, map_location=device)
|
||
|
||
# 创建模型实例
|
||
if model_type == 'mlstm':
|
||
model = MLSTMTransformer(
|
||
num_features=len(checkpoint['features']),
|
||
hidden_size=128,
|
||
mlstm_layers=1,
|
||
embed_dim=32,
|
||
dense_dim=32,
|
||
num_heads=4,
|
||
dropout_rate=0.1,
|
||
num_blocks=3,
|
||
output_sequence_length=checkpoint['T']
|
||
)
|
||
elif model_type == 'transformer':
|
||
model = TimeSeriesTransformer(
|
||
num_features=len(checkpoint['features']),
|
||
d_model=32,
|
||
nhead=4,
|
||
num_encoder_layers=3,
|
||
dim_feedforward=32,
|
||
dropout=0.1,
|
||
output_sequence_length=checkpoint['T']
|
||
)
|
||
elif model_type == 'kan':
|
||
model = KANForecaster(
|
||
input_features=len(checkpoint['features']),
|
||
hidden_sizes=[64, 128, 64],
|
||
output_size=1,
|
||
grid_size=5,
|
||
spline_order=3,
|
||
dropout_rate=0.1,
|
||
output_sequence_length=checkpoint['T']
|
||
)
|
||
else:
|
||
raise ValueError(f"不支持的模型类型: {model_type}")
|
||
|
||
# 加载模型参数
|
||
model.load_state_dict(checkpoint['model_state_dict'])
|
||
model = model.to(device)
|
||
model.eval()
|
||
|
||
print(f"模型已从 {os.path.abspath(model_path)} 成功加载")
|
||
return model, checkpoint
|
||
except Exception as e:
|
||
print(f"加载模型时出错: {str(e)}")
|
||
raise
|
||
|
||
def list_models(self, product_id=None, model_type=None):
|
||
"""
|
||
列出所有保存的模型
|
||
|
||
参数:
|
||
product_id: 按产品ID筛选 (可选)
|
||
model_type: 按模型类型筛选 (可选)
|
||
|
||
返回:
|
||
models_list: 模型信息列表
|
||
"""
|
||
self._ensure_model_dir()
|
||
|
||
# 构建搜索模式
|
||
if product_id and model_type:
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
|
||
elif product_id:
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_*_model_v*.pt")
|
||
elif model_type:
|
||
pattern = os.path.join(self.models_dir, f"*_{model_type}_model_v*.pt")
|
||
else:
|
||
pattern = os.path.join(self.models_dir, "*_model_v*.pt")
|
||
|
||
model_files = glob.glob(pattern)
|
||
|
||
if not model_files:
|
||
print("未找到匹配的模型文件")
|
||
return []
|
||
|
||
# 收集模型信息
|
||
models_list = []
|
||
for model_path in model_files:
|
||
try:
|
||
# 从文件名解析信息
|
||
filename = os.path.basename(model_path)
|
||
parts = filename.split('_')
|
||
if len(parts) < 4:
|
||
continue
|
||
|
||
product_id = parts[0]
|
||
model_type = parts[1]
|
||
version = parts[-1].replace('model_v', '').replace('.pt', '')
|
||
|
||
# 查找对应的元数据文件
|
||
meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
|
||
model_info = {
|
||
'product_id': product_id,
|
||
'model_type': model_type,
|
||
'version': version,
|
||
'file_path': model_path,
|
||
'created_at': datetime.fromtimestamp(os.path.getctime(model_path)).strftime("%Y-%m-%d %H:%M:%S"),
|
||
'file_size': f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
|
||
}
|
||
|
||
# 如果有元数据文件,添加更多信息
|
||
if os.path.exists(meta_path):
|
||
with open(meta_path, 'r') as f:
|
||
meta = json.load(f)
|
||
model_info.update(meta)
|
||
|
||
models_list.append(model_info)
|
||
except Exception as e:
|
||
print(f"解析模型文件 {model_path} 时出错: {str(e)}")
|
||
|
||
# 按创建时间排序
|
||
models_list.sort(key=lambda x: x['created_at'], reverse=True)
|
||
|
||
return models_list
|
||
|
||
def delete_model(self, product_id, model_type, version=None):
|
||
"""
|
||
删除指定的模型
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_type: 模型类型
|
||
version: 模型版本,如果不指定则删除所有版本
|
||
|
||
返回:
|
||
success: 是否成功删除
|
||
"""
|
||
self._ensure_model_dir()
|
||
|
||
if version:
|
||
# 删除特定版本
|
||
model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
|
||
meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
|
||
if not os.path.exists(model_path):
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
|
||
return False
|
||
|
||
try:
|
||
os.remove(model_path)
|
||
if os.path.exists(meta_path):
|
||
os.remove(meta_path)
|
||
print(f"已删除产品 {product_id} 的 {model_type} 模型版本 {version}")
|
||
return True
|
||
except Exception as e:
|
||
print(f"删除模型时出错: {str(e)}")
|
||
return False
|
||
else:
|
||
# 删除所有版本
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
|
||
meta_pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v*.json")
|
||
|
||
model_files = glob.glob(pattern)
|
||
meta_files = glob.glob(meta_pattern)
|
||
|
||
if not model_files:
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
|
||
return False
|
||
|
||
try:
|
||
for file_path in model_files:
|
||
os.remove(file_path)
|
||
|
||
for file_path in meta_files:
|
||
os.remove(file_path)
|
||
|
||
print(f"已删除产品 {product_id} 的所有 {model_type} 模型")
|
||
return True
|
||
except Exception as e:
|
||
print(f"删除模型时出错: {str(e)}")
|
||
return False
|
||
|
||
def get_model_details(self, product_id, model_type, version=None):
|
||
"""
|
||
获取模型的详细信息
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_type: 模型类型
|
||
version: 模型版本,如果不指定则获取最新版本
|
||
|
||
返回:
|
||
details: 模型详细信息字典
|
||
"""
|
||
# 查找匹配的模型文件
|
||
if version is None:
|
||
# 查找最新版本
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
|
||
model_files = glob.glob(pattern)
|
||
|
||
if not model_files:
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
|
||
return None
|
||
|
||
# 按照文件修改时间排序,获取最新的
|
||
model_path = max(model_files, key=os.path.getmtime)
|
||
# 从文件名解析版本
|
||
filename = os.path.basename(model_path)
|
||
version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '')
|
||
|
||
# 查找元数据文件
|
||
meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
|
||
if not os.path.exists(meta_path):
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version} 的元数据")
|
||
return None
|
||
|
||
try:
|
||
with open(meta_path, 'r') as f:
|
||
details = json.load(f)
|
||
|
||
# 添加文件路径
|
||
model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
|
||
details['file_path'] = model_path
|
||
details['file_size'] = f"{os.path.getsize(model_path) / (1024 * 1024):.2f} MB"
|
||
|
||
return details
|
||
except Exception as e:
|
||
print(f"获取模型详情时出错: {str(e)}")
|
||
return None
|
||
|
||
def predict_with_model(self, product_id, model_type='mlstm', version=None, future_days=7,
|
||
product_df=None, features=None, visualize=True, save_results=True):
|
||
"""
|
||
使用指定的模型进行预测
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_type: 模型类型 ('mlstm', 'transformer', 'kan')
|
||
version: 模型版本,如果不指定则使用最新版本
|
||
future_days: 要预测的未来天数
|
||
product_df: 产品数据DataFrame
|
||
features: 特征列表
|
||
visualize: 是否可视化结果
|
||
save_results: 是否保存结果
|
||
|
||
返回:
|
||
predictions_df: 预测结果DataFrame
|
||
"""
|
||
# 获取设备
|
||
device = get_device()
|
||
print(f"使用设备: {device} 进行预测")
|
||
|
||
# 加载模型
|
||
model, checkpoint = self.load_model(product_id, model_type, version, device)
|
||
|
||
if model is None or checkpoint is None:
|
||
return None
|
||
|
||
# 如果没有提供产品数据,则从Excel文件加载
|
||
if product_df is None:
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
except Exception as e:
|
||
print(f"加载产品数据时出错: {str(e)}")
|
||
return None
|
||
|
||
product_name = product_df['product_name'].iloc[0]
|
||
|
||
# 获取模型参数
|
||
features = checkpoint['features']
|
||
look_back = checkpoint['look_back']
|
||
T = checkpoint['T']
|
||
scaler_X = checkpoint['scaler_X']
|
||
scaler_y = checkpoint['scaler_y']
|
||
|
||
# 获取最近的look_back天数据
|
||
last_data = product_df[features].values[-look_back:]
|
||
last_data_scaled = scaler_X.transform(last_data)
|
||
|
||
# 准备输入数据
|
||
X_input = torch.Tensor(last_data_scaled).unsqueeze(0) # 添加批次维度
|
||
X_input = X_input.to(device) # 移动到设备上
|
||
|
||
# 进行预测
|
||
with torch.no_grad():
|
||
y_pred_scaled = model(X_input).squeeze(0).cpu().numpy() # 返回到CPU并转换为numpy
|
||
|
||
# 反归一化预测结果
|
||
y_pred = scaler_y.inverse_transform(y_pred_scaled.reshape(-1, 1)).flatten()
|
||
|
||
# 创建预测日期范围
|
||
last_date = product_df['date'].iloc[-1]
|
||
future_dates = pd.date_range(start=last_date + pd.Timedelta(days=1), periods=T, freq='D')
|
||
|
||
# 创建预测结果DataFrame
|
||
predictions_df = pd.DataFrame({
|
||
'date': future_dates,
|
||
'product_id': product_id,
|
||
'product_name': product_name,
|
||
'predicted_sales': y_pred
|
||
})
|
||
|
||
print(f"\n{product_name} 未来 {T} 天销售预测 (使用{model_type.upper()}模型):")
|
||
print(predictions_df[['date', 'predicted_sales']])
|
||
|
||
# 可视化预测结果
|
||
if visualize:
|
||
plt.figure(figsize=(12, 6))
|
||
|
||
# 显示历史数据和预测数据
|
||
history_days = 30 # 显示最近30天的历史数据
|
||
history_dates = product_df['date'].iloc[-history_days:].values
|
||
history_sales = product_df['sales'].iloc[-history_days:].values
|
||
|
||
plt.plot(history_dates, history_sales, 'b-', label='历史销量')
|
||
plt.plot(future_dates, y_pred, 'r--', label=f'{model_type.upper()}预测销量')
|
||
|
||
plt.title(f'{product_name} - {model_type.upper()}销量预测 (未来{T}天)')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45)
|
||
plt.tight_layout()
|
||
|
||
# 保存和显示图表
|
||
forecast_chart = f'{product_id}_{model_type}_forecast.png'
|
||
plt.savefig(forecast_chart)
|
||
print(f"预测图表已保存为: {forecast_chart}")
|
||
|
||
# 保存预测结果到CSV
|
||
if save_results:
|
||
forecast_csv = f'{product_id}_{model_type}_forecast.csv'
|
||
predictions_df.to_csv(forecast_csv, index=False)
|
||
print(f"预测结果已保存到: {forecast_csv}")
|
||
|
||
return predictions_df
|
||
|
||
def compare_models(self, product_id, model_types=None, versions=None, product_df=None, visualize=True):
|
||
"""
|
||
比较不同模型的预测结果
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_types: 要比较的模型类型列表
|
||
versions: 对应的模型版本列表,如果不指定则使用最新版本
|
||
product_df: 产品数据DataFrame
|
||
visualize: 是否可视化结果
|
||
|
||
返回:
|
||
比较结果DataFrame
|
||
"""
|
||
if model_types is None:
|
||
model_types = ['mlstm', 'transformer', 'kan']
|
||
|
||
if versions is None:
|
||
versions = [None] * len(model_types)
|
||
|
||
if len(versions) != len(model_types):
|
||
print("错误: 模型类型和版本列表长度不匹配")
|
||
return None
|
||
|
||
# 如果没有提供产品数据,则从Excel文件加载
|
||
if product_df is None:
|
||
try:
|
||
df = pd.read_excel('pharmacy_sales.xlsx')
|
||
product_df = df[df['product_id'] == product_id].sort_values('date')
|
||
except Exception as e:
|
||
print(f"加载产品数据时出错: {str(e)}")
|
||
return None
|
||
|
||
product_name = product_df['product_name'].iloc[0]
|
||
|
||
# 存储所有模型的预测结果
|
||
predictions = {}
|
||
|
||
# 对每个模型进行预测
|
||
for i, model_type in enumerate(model_types):
|
||
version = versions[i]
|
||
|
||
try:
|
||
pred_df = self.predict_with_model(
|
||
product_id,
|
||
model_type=model_type,
|
||
version=version,
|
||
product_df=product_df,
|
||
visualize=False,
|
||
save_results=False
|
||
)
|
||
|
||
if pred_df is not None:
|
||
predictions[model_type] = pred_df
|
||
except Exception as e:
|
||
print(f"{model_type} 模型预测出错: {str(e)}")
|
||
|
||
if not predictions:
|
||
print("没有成功的预测结果")
|
||
return None
|
||
|
||
# 合并预测结果
|
||
result_df = predictions[list(predictions.keys())[0]][['date', 'product_id', 'product_name']].copy()
|
||
|
||
for model_type, pred_df in predictions.items():
|
||
result_df[f'{model_type}_prediction'] = pred_df['predicted_sales'].values
|
||
|
||
# 可视化比较结果
|
||
if visualize and len(predictions) > 0:
|
||
plt.figure(figsize=(12, 6))
|
||
|
||
# 显示历史数据
|
||
history_days = 30 # 显示最近30天的历史数据
|
||
history_dates = product_df['date'].iloc[-history_days:].values
|
||
history_sales = product_df['sales'].iloc[-history_days:].values
|
||
|
||
plt.plot(history_dates, history_sales, 'k-', label='历史销量')
|
||
|
||
# 显示预测数据
|
||
colors = ['r', 'g', 'b', 'c', 'm', 'y']
|
||
future_dates = result_df['date'].values
|
||
|
||
for i, (model_type, pred_df) in enumerate(predictions.items()):
|
||
color = colors[i % len(colors)]
|
||
plt.plot(future_dates, pred_df['predicted_sales'].values,
|
||
f'{color}--', label=f'{model_type.upper()}预测')
|
||
|
||
plt.title(f'{product_name} - 不同模型预测结果比较')
|
||
plt.xlabel('日期')
|
||
plt.ylabel('销量')
|
||
plt.legend()
|
||
plt.grid(True)
|
||
plt.xticks(rotation=45)
|
||
plt.tight_layout()
|
||
|
||
# 保存和显示图表
|
||
compare_chart = f'{product_id}_model_comparison.png'
|
||
plt.savefig(compare_chart)
|
||
print(f"比较图表已保存为: {compare_chart}")
|
||
|
||
# 保存比较结果到CSV
|
||
compare_csv = f'{product_id}_model_comparison.csv'
|
||
result_df.to_csv(compare_csv, index=False)
|
||
print(f"比较结果已保存到: {compare_csv}")
|
||
|
||
return result_df
|
||
|
||
def export_model(self, product_id, model_type, version=None, export_dir='exported_models'):
|
||
"""
|
||
导出模型到指定目录
|
||
|
||
参数:
|
||
product_id: 产品ID
|
||
model_type: 模型类型
|
||
version: 模型版本,如果不指定则导出最新版本
|
||
export_dir: 导出目录
|
||
|
||
返回:
|
||
export_path: 导出的文件路径
|
||
"""
|
||
# 确保导出目录存在
|
||
if not os.path.exists(export_dir):
|
||
os.makedirs(export_dir, exist_ok=True)
|
||
|
||
# 查找匹配的模型文件
|
||
if version is None:
|
||
# 查找最新版本
|
||
pattern = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v*.pt")
|
||
model_files = glob.glob(pattern)
|
||
|
||
if not model_files:
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型文件")
|
||
return None
|
||
|
||
# 按照文件修改时间排序,获取最新的
|
||
model_path = max(model_files, key=os.path.getmtime)
|
||
# 从文件名解析版本
|
||
filename = os.path.basename(model_path)
|
||
version = filename.split('_')[-1].replace('model_v', '').replace('.pt', '')
|
||
else:
|
||
model_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_model_v{version}.pt")
|
||
if not os.path.exists(model_path):
|
||
print(f"错误: 未找到产品 {product_id} 的 {model_type} 模型版本 {version}")
|
||
return None
|
||
|
||
# 元数据文件
|
||
meta_path = os.path.join(self.models_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
|
||
# 导出路径
|
||
export_model_path = os.path.join(export_dir, f"{product_id}_{model_type}_model_v{version}.pt")
|
||
export_meta_path = os.path.join(export_dir, f"{product_id}_{model_type}_meta_v{version}.json")
|
||
|
||
try:
|
||
# 复制文件
|
||
shutil.copy2(model_path, export_model_path)
|
||
if os.path.exists(meta_path):
|
||
shutil.copy2(meta_path, export_meta_path)
|
||
|
||
print(f"模型已导出到 {os.path.abspath(export_model_path)}")
|
||
return export_model_path
|
||
except Exception as e:
|
||
print(f"导出模型时出错: {str(e)}")
|
||
return None
|
||
|
||
def import_model(self, import_file, overwrite=False):
|
||
"""
|
||
导入模型文件
|
||
|
||
参数:
|
||
import_file: 要导入的模型文件路径
|
||
overwrite: 如果存在同名文件是否覆盖
|
||
|
||
返回:
|
||
import_path: 导入后的文件路径
|
||
"""
|
||
self._ensure_model_dir()
|
||
|
||
if not os.path.exists(import_file):
|
||
print(f"错误: 导入文件 {import_file} 不存在")
|
||
return None
|
||
|
||
# 获取文件名
|
||
filename = os.path.basename(import_file)
|
||
|
||
# 目标路径
|
||
target_path = os.path.join(self.models_dir, filename)
|
||
|
||
# 检查是否存在同名文件
|
||
if os.path.exists(target_path) and not overwrite:
|
||
print(f"错误: 目标文件 {target_path} 已存在,如需覆盖请设置overwrite=True")
|
||
return None
|
||
|
||
try:
|
||
# 复制文件
|
||
shutil.copy2(import_file, target_path)
|
||
|
||
# 如果有对应的元数据文件,也一并导入
|
||
meta_filename = filename.replace('_model_v', '_meta_v')
|
||
meta_import_file = import_file.replace('_model_v', '_meta_v').replace('.pt', '.json')
|
||
meta_target_path = os.path.join(self.models_dir, meta_filename.replace('.pt', '.json'))
|
||
|
||
if os.path.exists(meta_import_file):
|
||
shutil.copy2(meta_import_file, meta_target_path)
|
||
|
||
print(f"模型已导入到 {os.path.abspath(target_path)}")
|
||
return target_path
|
||
except Exception as e:
|
||
print(f"导入模型时出错: {str(e)}")
|
||
return None |