ShopTRAINING/server/utils/model_manager.py
"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR


class ModelManager:
    """Unified model manager."""

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = model_dir
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Make sure the model directory exists."""
        if not os.path.exists(self.model_dir):
            os.makedirs(self.model_dir)

    def generate_model_filename(self,
                                product_id: str,
                                model_type: str,
                                version: str,
                                store_id: Optional[str] = None,
                                training_mode: str = 'product',
                                aggregation_method: Optional[str] = None) -> str:
        """
        Generate a unified model filename.

        Naming conventions:
        - Product mode: {model_type}_product_{product_id}_{version}.pth
        - Store mode:   {model_type}_store_{store_id}_{product_id}_{version}.pth
        - Global mode:  {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
        """
        if training_mode == 'store' and store_id:
            return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
        elif training_mode == 'global' and aggregation_method:
            return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
        else:
            # Default to product mode
            return f"{model_type}_product_{product_id}_{version}.pth"

    def save_model(self,
                   model_data: dict,
                   product_id: str,
                   model_type: str,
                   version: str,
                   store_id: Optional[str] = None,
                   training_mode: str = 'product',
                   aggregation_method: Optional[str] = None,
                   product_name: Optional[str] = None) -> str:
        """
        Save a model to the unified location.

        Args:
            model_data: dict containing the model state and configuration
            product_id: product ID
            model_type: model type
            version: version string
            store_id: store ID (optional)
            training_mode: training mode
            aggregation_method: aggregation method (optional)
            product_name: product name (optional)

        Returns:
            Path of the saved model file.
        """
        filename = self.generate_model_filename(
            product_id, model_type, version, store_id, training_mode, aggregation_method
        )

        # Save everything into the root model directory to avoid a complex
        # sub-directory layout.
        model_path = os.path.join(self.model_dir, filename)

        # Enrich the model data with management metadata.
        enhanced_model_data = model_data.copy()
        enhanced_model_data.update({
            'model_manager_info': {
                'product_id': product_id,
                'product_name': product_name or product_id,
                'model_type': model_type,
                'version': version,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'created_at': datetime.now().isoformat(),
                'filename': filename
            }
        })

        # Persist the model.
        torch.save(enhanced_model_data, model_path)
        print(f"Model saved: {model_path}")
        return model_path
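
    # Minimal call sketch (assumed call site; the 'state_dict', 'config' and
    # 'metrics' keys are whatever the training code chooses to store):
    #   manager.save_model(
    #       model_data={'state_dict': model.state_dict(), 'config': cfg, 'metrics': m},
    #       product_id='P001', model_type='transformer', version='v2')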

    def list_models(self,
                    product_id: Optional[str] = None,
                    model_type: Optional[str] = None,
                    store_id: Optional[str] = None,
                    training_mode: Optional[str] = None,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List all model files.

        Args:
            product_id: filter by product ID (optional)
            model_type: filter by model type (optional)
            store_id: filter by store ID (optional)
            training_mode: filter by training mode (optional)
            page: page number, starting at 1 (optional)
            page_size: number of items per page (optional)

        Returns:
            Dict containing the model list and pagination information.
        """
        models = []

        # Scan all .pth files in the model directory.
        pattern = os.path.join(self.model_dir, "*.pth")
        model_files = glob.glob(pattern)

        for model_file in model_files:
            try:
                # Parse the filename.
                filename = os.path.basename(model_file)
                model_info = self.parse_model_filename(filename)
                if not model_info:
                    continue

                # Try to read extra information from the model file itself.
                try:
                    # Try weights_only=False first for backward compatibility.
                    try:
                        model_data = torch.load(model_file, map_location='cpu', weights_only=False)
                    except Exception:
                        # If that fails, try weights_only=True (newer PyTorch versions).
                        model_data = torch.load(model_file, map_location='cpu', weights_only=True)

                    if 'model_manager_info' in model_data:
                        # Prefer the embedded management metadata.
                        manager_info = model_data['model_manager_info']
                        model_info.update(manager_info)

                    # Attach evaluation metrics.
                    if 'metrics' in model_data:
                        model_info['metrics'] = model_data['metrics']

                    # Attach configuration information.
                    if 'config' in model_data:
                        model_info['config'] = model_data['config']
                except Exception as e:
                    print(f"Failed to read model file {model_file}: {e}")
                    # Continue with just the filename-based info.

                # Apply the filters.
                if product_id and model_info.get('product_id') != product_id:
                    continue
                if model_type and model_info.get('model_type') != model_type:
                    continue
                if store_id and model_info.get('store_id') != store_id:
                    continue
                if training_mode and model_info.get('training_mode') != training_mode:
                    continue

                # Attach file-level information.
                model_info['filename'] = filename
                model_info['file_path'] = model_file
                model_info['file_size'] = os.path.getsize(model_file)
                model_info['modified_at'] = datetime.fromtimestamp(
                    os.path.getmtime(model_file)
                ).isoformat()

                models.append(model_info)
            except Exception as e:
                print(f"Failed to process model file {model_file}: {e}")
                continue

        # Sort by creation time, newest first.
        models.sort(key=lambda x: x.get('created_at', x.get('modified_at', '')), reverse=True)

        # Compute pagination information.
        total_count = len(models)

        # If no pagination parameters were given, return everything.
        if page is None or page_size is None:
            return {
                'models': models,
                'pagination': {
                    'total': total_count,
                    'page': 1,
                    'page_size': total_count,
                    'total_pages': 1,
                    'has_next': False,
                    'has_previous': False
                }
            }

        # Apply pagination.
        total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 1
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        paginated_models = models[start_index:end_index]

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page,
                'page_size': page_size,
                'total_pages': total_pages,
                'has_next': page < total_pages,
                'has_previous': page > 1
            }
        }
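
    # Illustrative call (assumed filter values): the second page of transformer
    # models for one product, 10 per page.
    #   result = manager.list_models(product_id='P001', model_type='transformer',
    #                                page=2, page_size=10)
    #   result['models']      -> list of model-info dicts for that page
    #   result['pagination']  -> {'total': ..., 'page': 2, 'has_next': ..., ...}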

    def parse_model_filename(self, filename: str) -> Optional[Dict]:
        """
        Parse a model filename and extract the model information.

        Supported formats:
        - {model_type}_product_{product_id}_{version}.pth
        - {model_type}_store_{store_id}_{product_id}_{version}.pth
        - {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
        - legacy formats (for backward compatibility)
        """
        if not filename.endswith('.pth'):
            return None

        base_name = filename.replace('.pth', '')

        try:
            # New-format parsing. Legacy names containing '_model_product_'
            # are excluded here so they reach the legacy branch below.
            if '_product_' in base_name and '_model_product_' not in base_name:
                # Product mode: model_type_product_product_id_version
                parts = base_name.split('_product_')
                model_type = parts[0]
                rest = parts[1]

                # Separate product ID and version.
                if '_v' in rest:
                    last_v_index = rest.rfind('_v')
                    product_id = rest[:last_v_index]
                    version = rest[last_v_index+1:]
                else:
                    product_id = rest
                    version = 'v1'

                return {
                    'model_type': model_type,
                    'product_id': product_id,
                    'version': version,
                    'training_mode': 'product',
                    'store_id': None,
                    'aggregation_method': None
                }

            elif '_store_' in base_name:
                # Store mode: model_type_store_store_id_product_id_version
                parts = base_name.split('_store_')
                model_type = parts[0]
                rest = parts[1]

                # Separate store ID, product ID and version.
                rest_parts = rest.split('_')
                if len(rest_parts) >= 3:
                    store_id = rest_parts[0]
                    if rest_parts[-1].startswith('v'):
                        # The last part is the version.
                        version = rest_parts[-1]
                        product_id = '_'.join(rest_parts[1:-1])
                    else:
                        version = 'v1'
                        product_id = '_'.join(rest_parts[1:])

                    return {
                        'model_type': model_type,
                        'product_id': product_id,
                        'version': version,
                        'training_mode': 'store',
                        'store_id': store_id,
                        'aggregation_method': None
                    }

            elif '_global_' in base_name:
                # Global mode: model_type_global_product_id_aggregation_method_version
                parts = base_name.split('_global_')
                model_type = parts[0]
                rest = parts[1]

                rest_parts = rest.split('_')
                if len(rest_parts) >= 3:
                    if rest_parts[-1].startswith('v'):
                        # The last part is the version.
                        version = rest_parts[-1]
                        aggregation_method = rest_parts[-2]
                        product_id = '_'.join(rest_parts[:-2])
                    else:
                        version = 'v1'
                        aggregation_method = rest_parts[-1]
                        product_id = '_'.join(rest_parts[:-1])

                    return {
                        'model_type': model_type,
                        'product_id': product_id,
                        'version': version,
                        'training_mode': 'global',
                        'store_id': None,
                        'aggregation_method': aggregation_method
                    }

            # Legacy-format compatibility.
            else:
                # Try to parse other (older) formats.
                if 'model_product_' in base_name:
                    parts = base_name.split('_model_product_')
                    model_type = parts[0]
                    product_part = parts[1]

                    if '_v' in product_part:
                        last_v_index = product_part.rfind('_v')
                        product_id = product_part[:last_v_index]
                        version = product_part[last_v_index+1:]
                    else:
                        product_id = product_part
                        version = 'v1'

                    return {
                        'model_type': model_type,
                        'product_id': product_id,
                        'version': version,
                        'training_mode': 'product',
                        'store_id': None,
                        'aggregation_method': None
                    }
        except Exception as e:
            print(f"Failed to parse filename {filename}: {e}")

        return None
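
    # Example round-trip with a hypothetical store-mode filename:
    #   parse_model_filename('mlstm_store_S001_P001_v1.pth')
    #       -> {'model_type': 'mlstm', 'product_id': 'P001', 'version': 'v1',
    #           'training_mode': 'store', 'store_id': 'S001',
    #           'aggregation_method': None}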

    def delete_model(self, model_file: str) -> bool:
        """Delete a model file."""
        try:
            if os.path.exists(model_file):
                os.remove(model_file)
                print(f"Deleted model file: {model_file}")
                return True
            else:
                print(f"Model file does not exist: {model_file}")
                return False
        except Exception as e:
            print(f"Failed to delete model file: {e}")
            return False

    def get_model_by_id(self, model_id: str) -> Optional[Dict]:
        """Look up model information by model ID (the filename without '.pth')."""
        # list_models() returns a dict with 'models' and 'pagination' keys,
        # so iterate over the 'models' list rather than the dict itself.
        result = self.list_models()
        for model in result.get('models', []):
            if model.get('filename', '').replace('.pth', '') == model_id:
                return model
        return None


# Global model manager instance.
# Resolve saved_models relative to the project root rather than the current
# working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))  # two levels up to the project root
absolute_model_dir = os.path.join(project_root, 'saved_models')

model_manager = ModelManager(absolute_model_dir)
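

# Minimal end-to-end usage sketch (assumptions: a trained PyTorch model 'net'
# and a config dict 'cfg' exist at the call site; the IDs and model type are
# hypothetical). It exercises save -> list -> look up -> delete on the global
# instance.
#
#   path = model_manager.save_model(
#       model_data={'state_dict': net.state_dict(), 'config': cfg},
#       product_id='P001', model_type='transformer', version='v1',
#       product_name='Sample product')
#   listing = model_manager.list_models(product_id='P001', page=1, page_size=20)
#   info = model_manager.get_model_by_id('transformer_product_P001_v1')
#   model_manager.delete_model(path)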