# ShopTRAINING/server/utils/model_manager.py
"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
"""
import os
import json
import torch
import glob
from datetime import datetime
import re
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR
class ModelManager:
    """
    Unified model manager.

    Stores every model checkpoint (.pth) flat in a single directory using a
    deterministic filename scheme, and provides version management, listing,
    filename parsing and deletion on top of that scheme.
    """

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = model_dir
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Create the model directory if it does not already exist."""
        # exist_ok=True avoids the check-then-create race of the previous
        # os.path.exists() + os.makedirs() sequence.
        os.makedirs(self.model_dir, exist_ok=True)

    def _get_next_version(self, model_type: str, product_id: Optional[str] = None, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> int:
        """
        Return the next numeric model version (1-based).

        Globs the model directory for files matching this model identity and
        returns max(existing numeric versions) + 1.
        """
        search_pattern = self.generate_model_filename(
            model_type=model_type,
            version='v*',
            product_id=product_id,
            store_id=store_id,
            training_mode=training_mode,
            aggregation_method=aggregation_method
        )
        full_search_path = os.path.join(self.model_dir, search_pattern)
        existing_files = glob.glob(full_search_path)
        max_version = 0
        for f in existing_files:
            # Only files with a purely numeric "_v<N>.pth" suffix count;
            # named versions such as "_vbest.pth" are ignored here.
            match = re.search(r'_v(\d+)\.pth$', os.path.basename(f))
            if match:
                max_version = max(max_version, int(match.group(1)))
        return max_version + 1

    def generate_model_filename(self,
                                model_type: str,
                                version: str,
                                training_mode: str = 'product',
                                product_id: Optional[str] = None,
                                store_id: Optional[str] = None,
                                aggregation_method: Optional[str] = None) -> str:
        """
        Build the canonical model filename.

        Naming scheme (v2):
        - product mode: {model_type}_product_{product_id}_{version}.pth
        - store mode:   {model_type}_store_{store_id}_{version}.pth
        - global mode:  {model_type}_global_{aggregation_method}_{version}.pth

        Raises:
            ValueError: if the identifier required by `training_mode` is missing.
        """
        if training_mode == 'store' and store_id:
            return f"{model_type}_store_{store_id}_{version}.pth"
        elif training_mode == 'global' and aggregation_method:
            return f"{model_type}_global_{aggregation_method}_{version}.pth"
        elif training_mode == 'product' and product_id:
            return f"{model_type}_product_{product_id}_{version}.pth"
        else:
            # Fail loudly rather than silently producing an invalid filename.
            raise ValueError(f"无法为训练模式 '{training_mode}' 生成有效的文件名缺少必需的ID。")

    def save_model(self,
                   model_data: dict,
                   product_id: str,
                   model_type: str,
                   store_id: Optional[str] = None,
                   training_mode: str = 'product',
                   aggregation_method: Optional[str] = None,
                   product_name: Optional[str] = None,
                   version: Optional[str] = None) -> Tuple[str, str]:
        """
        Save a model checkpoint into the managed directory with automatic
        version handling.

        Parameters:
            model_data: checkpoint payload (passed to torch.save).
            version: optional explicit version label (e.g. 'best'); when None,
                the next numeric version ('v1', 'v2', ...) is assigned.

        Returns:
            (model file path, version string actually used)
        """
        if version is None:
            next_version_num = self._get_next_version(
                model_type=model_type,
                product_id=product_id,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method
            )
            version_str = f"v{next_version_num}"
        else:
            version_str = version
        filename = self.generate_model_filename(
            model_type=model_type,
            version=version_str,
            training_mode=training_mode,
            product_id=product_id,
            store_id=store_id,
            aggregation_method=aggregation_method
        )
        # All checkpoints live flat in model_dir; no nested sub-directories.
        model_path = os.path.join(self.model_dir, filename)
        # Shallow-copy so the caller's dict is not mutated, then attach
        # management metadata used by list_models()/get_model_by_id().
        enhanced_model_data = model_data.copy()
        enhanced_model_data.update({
            'model_manager_info': {
                'product_id': product_id,
                'product_name': product_name or product_id,
                'model_type': model_type,
                'version': version_str,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'created_at': datetime.now().isoformat(),
                'filename': filename
            }
        })
        torch.save(enhanced_model_data, model_path)
        print(f"模型已保存: {model_path}")
        return model_path, version_str

    def list_models(self,
                    product_id: Optional[str] = None,
                    model_type: Optional[str] = None,
                    store_id: Optional[str] = None,
                    training_mode: Optional[str] = None,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List model files in the managed directory.

        Parameters:
            product_id: filter by product id (optional)
            model_type: filter by model type (optional)
            store_id: filter by store id (optional)
            training_mode: filter by training mode (optional)
            page: 1-based page number (optional)
            page_size: page size (optional)

        Returns:
            dict with 'models' (list of info dicts, newest first) and
            'pagination' metadata.
        """
        models = []
        pattern = os.path.join(self.model_dir, "*.pth")
        model_files = glob.glob(pattern)
        for model_file in model_files:
            try:
                filename = os.path.basename(model_file)
                model_info = self.parse_model_filename(filename)
                if not model_info:
                    continue
                # Enrich with metadata stored inside the checkpoint, if readable.
                # NOTE(review): loading every full checkpoint just to list it is
                # expensive for large models — consider a sidecar index file.
                try:
                    # weights_only=False first for backward compatibility with
                    # checkpoints that contain arbitrary pickled objects.
                    try:
                        model_data = torch.load(model_file, map_location='cpu', weights_only=False)
                    except Exception:
                        # Fall back for newer/restricted PyTorch loaders.
                        model_data = torch.load(model_file, map_location='cpu', weights_only=True)
                    if 'model_manager_info' in model_data:
                        manager_info = model_data['model_manager_info']
                        model_info.update(manager_info)
                    if 'metrics' in model_data:
                        model_info['metrics'] = model_data['metrics']
                    if 'config' in model_data:
                        model_info['config'] = model_data['config']
                except Exception as e:
                    print(f"读取模型文件失败 {model_file}: {e}")
                    # Best-effort: continue with filename-derived info only.
                # Apply the optional filters.
                if product_id and model_info.get('product_id') != product_id:
                    continue
                if model_type and model_info.get('model_type') != model_type:
                    continue
                if store_id and model_info.get('store_id') != store_id:
                    continue
                if training_mode and model_info.get('training_mode') != training_mode:
                    continue
                # File-level metadata.
                model_info['filename'] = filename
                model_info['file_path'] = model_file
                model_info['file_size'] = os.path.getsize(model_file)
                model_info['modified_at'] = datetime.fromtimestamp(
                    os.path.getmtime(model_file)
                ).isoformat()
                models.append(model_info)
            except Exception as e:
                print(f"处理模型文件失败 {model_file}: {e}")
                continue
        # Newest first; fall back to mtime when the checkpoint carried no
        # created_at metadata.
        models.sort(key=lambda x: x.get('created_at', x.get('modified_at', '')), reverse=True)
        total_count = len(models)
        # Without explicit paging parameters, return everything as one page.
        if page is None or page_size is None:
            return {
                'models': models,
                'pagination': {
                    'total': total_count,
                    'page': 1,
                    'page_size': total_count,
                    'total_pages': 1,
                    'has_next': False,
                    'has_previous': False
                }
            }
        # Ceiling division; guard against page_size <= 0.
        total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 1
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        paginated_models = models[start_index:end_index]
        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page,
                'page_size': page_size,
                'total_pages': total_pages,
                'has_next': page < total_pages,
                'has_previous': page > 1
            }
        }

    def parse_model_filename(self, filename: str) -> Optional[Dict]:
        """
        Parse a model filename into its components (v2 scheme).

        Supported formats:
        - product: {model_type}_product_{product_id}_{version}.pth
        - store:   {model_type}_store_{store_id}_{version}.pth
        - global:  {model_type}_global_{aggregation_method}_{version}.pth

        Returns None when the filename does not match any known layout.
        """
        if not filename.endswith('.pth'):
            return None
        # BUG FIX: strip only the trailing extension; str.replace() would also
        # remove any '.pth' occurring inside the base name.
        base_name = filename[:-len('.pth')]
        parts = base_name.split('_')
        if len(parts) < 3:
            return None  # too few segments for any supported layout
        # Parse from the end so model_type itself may contain underscores.
        # NOTE(review): the identifier segment (product/store id, aggregation
        # method) must itself be underscore-free for this to parse correctly.
        try:
            version = parts[-1]
            identifier = parts[-2]
            mode_candidate = parts[-3]
            if mode_candidate == 'product':
                model_type = '_'.join(parts[:-3])
                return {
                    'model_type': model_type,
                    'training_mode': 'product',
                    'product_id': identifier,
                    'version': version,
                }
            elif mode_candidate == 'store':
                model_type = '_'.join(parts[:-3])
                return {
                    'model_type': model_type,
                    'training_mode': 'store',
                    'store_id': identifier,
                    'version': version,
                }
            elif mode_candidate == 'global':
                model_type = '_'.join(parts[:-3])
                return {
                    'model_type': model_type,
                    'training_mode': 'global',
                    'aggregation_method': identifier,
                    'version': version,
                }
        except IndexError:
            # Not enough filename segments: treat as unparseable.
            pass
        except Exception as e:
            print(f"解析文件名失败 (unknown): {e}")
        return None

    def delete_model(self, model_file: str) -> bool:
        """Delete a model file; return True on success, False otherwise."""
        try:
            if os.path.exists(model_file):
                os.remove(model_file)
                print(f"已删除模型文件: {model_file}")
                return True
            else:
                print(f"模型文件不存在: {model_file}")
                return False
        except Exception as e:
            print(f"删除模型文件失败: {e}")
            return False

    def get_model_by_id(self, model_id: str) -> Optional[Dict]:
        """
        Look up a model by its id (the filename without the .pth extension).

        Returns the model info dict, or None when no model matches.
        """
        # BUG FIX: list_models() returns {'models': [...], 'pagination': {...}};
        # the previous code iterated the dict itself, i.e. its string keys,
        # and 'str'.get(...) would raise AttributeError.
        result = self.list_models()
        for model in result.get('models', []):
            # splitext strips exactly one trailing extension, unlike replace().
            if os.path.splitext(model.get('filename', ''))[0] == model_id:
                return model
        return None
# Global ModelManager instance.
# Anchor saved_models at the project root so the location does not depend on
# the process's current working directory. (The redundant local `import os`
# was removed: os is already imported at the top of this module.)
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir))  # two levels up: utils -> server -> project root
absolute_model_dir = os.path.join(project_root, 'saved_models')
model_manager = ModelManager(absolute_model_dir)