""" 统一模型管理工具 处理模型文件的统一命名、存储和检索 """ import os import json import torch import glob from datetime import datetime import re from typing import List, Dict, Optional, Tuple from core.config import DEFAULT_MODEL_DIR class ModelManager: """统一模型管理器""" def __init__(self, model_dir: str = DEFAULT_MODEL_DIR): self.model_dir = model_dir self.ensure_model_dir() def ensure_model_dir(self): """确保模型目录存在""" if not os.path.exists(self.model_dir): os.makedirs(self.model_dir) def _get_next_version(self, product_id: str, model_type: str, store_id: Optional[str] = None, training_mode: str = 'product') -> int: """获取下一个模型版本号 (纯数字)""" search_pattern = self.generate_model_filename( product_id=product_id, model_type=model_type, version='v*', store_id=store_id, training_mode=training_mode ) full_search_path = os.path.join(self.model_dir, search_pattern) existing_files = glob.glob(full_search_path) max_version = 0 for f in existing_files: match = re.search(r'_v(\d+)\.pth$', os.path.basename(f)) if match: max_version = max(max_version, int(match.group(1))) return max_version + 1 def generate_model_filename(self, product_id: str, model_type: str, version: str, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None) -> str: """ 生成统一的模型文件名 格式规范: - 产品模式: {model_type}_product_{product_id}_{version}.pth - 店铺模式: {model_type}_store_{store_id}_{product_id}_{version}.pth - 全局模式: {model_type}_global_{product_id}_{aggregation_method}_{version}.pth """ if training_mode == 'store' and store_id: return f"{model_type}_store_{store_id}_{product_id}_{version}.pth" elif training_mode == 'global' and aggregation_method: return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth" else: # 默认产品模式 return f"{model_type}_product_{product_id}_{version}.pth" def save_model(self, model_data: dict, product_id: str, model_type: str, store_id: Optional[str] = None, training_mode: str = 'product', aggregation_method: Optional[str] = None, product_name: Optional[str] = None, version: Optional[str] = None) -> Tuple[str, str]: """ 保存模型到统一位置,并自动管理版本。 参数: ... version: (可选) 如果提供,则覆盖自动版本控制 (如 'best')。 返回: (模型文件路径, 使用的版本号) """ if version is None: next_version_num = self._get_next_version(product_id, model_type, store_id, training_mode) version_str = f"v{next_version_num}" else: version_str = version filename = self.generate_model_filename( product_id, model_type, version_str, store_id, training_mode, aggregation_method ) # 统一保存到根目录,避免复杂的子目录结构 model_path = os.path.join(self.model_dir, filename) # 增强模型数据,添加管理信息 enhanced_model_data = model_data.copy() enhanced_model_data.update({ 'model_manager_info': { 'product_id': product_id, 'product_name': product_name or product_id, 'model_type': model_type, 'version': version_str, 'store_id': store_id, 'training_mode': training_mode, 'aggregation_method': aggregation_method, 'created_at': datetime.now().isoformat(), 'filename': filename } }) # 保存模型 torch.save(enhanced_model_data, model_path) print(f"模型已保存: {model_path}") return model_path, version_str def list_models(self, product_id: Optional[str] = None, model_type: Optional[str] = None, store_id: Optional[str] = None, training_mode: Optional[str] = None, page: Optional[int] = None, page_size: Optional[int] = None) -> Dict: """ 列出所有模型文件 参数: product_id: 产品ID过滤 (可选) model_type: 模型类型过滤 (可选) store_id: 店铺ID过滤 (可选) training_mode: 训练模式过滤 (可选) page: 页码,从1开始 (可选) page_size: 每页数量 (可选) 返回: 包含模型列表和分页信息的字典 """ models = [] # 搜索所有.pth文件 pattern = os.path.join(self.model_dir, "*.pth") model_files = glob.glob(pattern) for model_file in model_files: try: # 解析文件名 filename = os.path.basename(model_file) model_info = self.parse_model_filename(filename) if not model_info: continue # 尝试从模型文件中读取额外信息 try: # Try with weights_only=False first for backward compatibility try: model_data = torch.load(model_file, map_location='cpu', weights_only=False) except Exception: # If that fails, try with weights_only=True (newer PyTorch versions) model_data = torch.load(model_file, map_location='cpu', weights_only=True) if 'model_manager_info' in model_data: # 使用新的管理信息 manager_info = model_data['model_manager_info'] model_info.update(manager_info) # 添加评估指标 if 'metrics' in model_data: model_info['metrics'] = model_data['metrics'] # 添加配置信息 if 'config' in model_data: model_info['config'] = model_data['config'] except Exception as e: print(f"读取模型文件失败 {model_file}: {e}") # Continue with just the filename-based info # 应用过滤器 if product_id and model_info.get('product_id') != product_id: continue if model_type and model_info.get('model_type') != model_type: continue if store_id and model_info.get('store_id') != store_id: continue if training_mode and model_info.get('training_mode') != training_mode: continue # 添加文件信息 model_info['filename'] = filename model_info['file_path'] = model_file model_info['file_size'] = os.path.getsize(model_file) model_info['modified_at'] = datetime.fromtimestamp( os.path.getmtime(model_file) ).isoformat() models.append(model_info) except Exception as e: print(f"处理模型文件失败 {model_file}: {e}") continue # 按创建时间排序(最新的在前) models.sort(key=lambda x: x.get('created_at', x.get('modified_at', '')), reverse=True) # 计算分页信息 total_count = len(models) # 如果没有指定分页参数,返回所有数据 if page is None or page_size is None: return { 'models': models, 'pagination': { 'total': total_count, 'page': 1, 'page_size': total_count, 'total_pages': 1, 'has_next': False, 'has_previous': False } } # 应用分页 total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 1 start_index = (page - 1) * page_size end_index = start_index + page_size paginated_models = models[start_index:end_index] return { 'models': paginated_models, 'pagination': { 'total': total_count, 'page': page, 'page_size': page_size, 'total_pages': total_pages, 'has_next': page < total_pages, 'has_previous': page > 1 } } def parse_model_filename(self, filename: str) -> Optional[Dict]: """ 解析模型文件名,提取模型信息 支持的格式: - {model_type}_product_{product_id}_{version}.pth - {model_type}_store_{store_id}_{product_id}_{version}.pth - {model_type}_global_{product_id}_{aggregation_method}_{version}.pth - 旧格式兼容 """ if not filename.endswith('.pth'): return None base_name = filename.replace('.pth', '') try: # 新格式解析 if '_product_' in base_name: # 产品模式: model_type_product_product_id_version parts = base_name.split('_product_') model_type = parts[0] rest = parts[1] # 分离产品ID和版本 if '_v' in rest: last_v_index = rest.rfind('_v') product_id = rest[:last_v_index] version = rest[last_v_index+1:] else: product_id = rest version = 'v1' return { 'model_type': model_type, 'product_id': product_id, 'version': version, 'training_mode': 'product', 'store_id': None, 'aggregation_method': None } elif '_store_' in base_name: # 店铺模式: model_type_store_store_id_product_id_version parts = base_name.split('_store_') model_type = parts[0] rest = parts[1] # 分离店铺ID、产品ID和版本 rest_parts = rest.split('_') if len(rest_parts) >= 3: store_id = rest_parts[0] if rest_parts[-1].startswith('v'): # 最后一部分是版本号 version = rest_parts[-1] product_id = '_'.join(rest_parts[1:-1]) else: version = 'v1' product_id = '_'.join(rest_parts[1:]) return { 'model_type': model_type, 'product_id': product_id, 'version': version, 'training_mode': 'store', 'store_id': store_id, 'aggregation_method': None } elif '_global_' in base_name: # 全局模式: model_type_global_product_id_aggregation_method_version parts = base_name.split('_global_') model_type = parts[0] rest = parts[1] rest_parts = rest.split('_') if len(rest_parts) >= 3: if rest_parts[-1].startswith('v'): # 最后一部分是版本号 version = rest_parts[-1] aggregation_method = rest_parts[-2] product_id = '_'.join(rest_parts[:-2]) else: version = 'v1' aggregation_method = rest_parts[-1] product_id = '_'.join(rest_parts[:-1]) return { 'model_type': model_type, 'product_id': product_id, 'version': version, 'training_mode': 'global', 'store_id': None, 'aggregation_method': aggregation_method } # 兼容旧格式 else: # 尝试解析其他格式 if 'model_product_' in base_name: parts = base_name.split('_model_product_') model_type = parts[0] product_part = parts[1] if '_v' in product_part: last_v_index = product_part.rfind('_v') product_id = product_part[:last_v_index] version = product_part[last_v_index+1:] else: product_id = product_part version = 'v1' return { 'model_type': model_type, 'product_id': product_id, 'version': version, 'training_mode': 'product', 'store_id': None, 'aggregation_method': None } except Exception as e: print(f"解析文件名失败 {filename}: {e}") return None def delete_model(self, model_file: str) -> bool: """删除模型文件""" try: if os.path.exists(model_file): os.remove(model_file) print(f"已删除模型文件: {model_file}") return True else: print(f"模型文件不存在: {model_file}") return False except Exception as e: print(f"删除模型文件失败: {e}") return False def get_model_by_id(self, model_id: str) -> Optional[Dict]: """根据模型ID获取模型信息""" models = self.list_models() for model in models: if model.get('filename', '').replace('.pth', '') == model_id: return model return None # 全局模型管理器实例 # 确保使用项目根目录的saved_models,而不是相对于当前工作目录 import os current_dir = os.path.dirname(os.path.abspath(__file__)) project_root = os.path.dirname(os.path.dirname(current_dir)) # 向上两级到项目根目录 absolute_model_dir = os.path.join(project_root, 'saved_models') model_manager = ModelManager(absolute_model_dir)