"""
Unified model management utility.

Handles consistent naming, storage and retrieval of model files.
"""
import os
import re
import glob
from datetime import datetime
from typing import Dict, Optional, Tuple

import torch

from core.config import DEFAULT_MODEL_DIR

# Filename grammar (v2): {model_type}_{training_mode}_{identifier}_{version}.pth
# Parsed with a regex rather than str.split('_'), so model types and IDs may
# contain underscores and versions such as 'v3_best' are still recognised.
_MODEL_FILENAME_RE = re.compile(
    r'(?P<model_type>.+?)_(?P<training_mode>product|store|global)_'
    r'(?P<identifier>.+?)_(?P<version>v\d+(?:_best)?|best)\.pth'
)

# Metadata key that the filename identifier maps to in each training mode.
_IDENTIFIER_KEYS = {
    'product': 'product_id',
    'store': 'store_id',
    'global': 'aggregation_method',
}


class ModelManager:
    """Unified model manager."""

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = model_dir
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Make sure the model directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _get_next_version(self, model_type: str, product_id: Optional[str] = None,
                          store_id: Optional[str] = None, training_mode: str = 'product',
                          aggregation_method: Optional[str] = None) -> int:
        """Return the next model version number (a plain integer)."""
        # Prefix shared by every version of this model, e.g. 'mlstm_product_P001_'
        # (obtained by generating a filename with an empty version string).
        prefix = self.generate_model_filename(
            model_type=model_type,
            version='',
            product_id=product_id,
            store_id=store_id,
            training_mode=training_mode,
            aggregation_method=aggregation_method,
        )[:-len('.pth')]
        # Match both '<prefix>v1.pth' and '<prefix>v1_best.pth'. Escaping the prefix
        # and requiring a full match keeps IDs such as 'P1_v9' (or IDs containing
        # glob/regex metacharacters) from being counted as versions of 'P1'.
        version_re = re.compile(re.escape(prefix) + r'v(\d+)(?:_best)?\.pth')

        max_version = 0
        for name in os.listdir(self.model_dir):
            match = version_re.fullmatch(name)
            if match:
                max_version = max(max_version, int(match.group(1)))
        return max_version + 1

    def peek_next_version(self, model_type: str, product_id: Optional[str] = None,
                          store_id: Optional[str] = None, training_mode: str = 'product',
                          aggregation_method: Optional[str] = None) -> str:
        """
        Preview the next version string (e.g. 'v3') without performing any file operations.
        """
        next_version_num = self._get_next_version(
            model_type=model_type,
            product_id=product_id,
            store_id=store_id,
            training_mode=training_mode,
            aggregation_method=aggregation_method,
        )
        return f"v{next_version_num}"

    def generate_model_filename(self, model_type: str, version: str,
                                training_mode: str = 'product',
                                product_id: Optional[str] = None,
                                store_id: Optional[str] = None,
                                aggregation_method: Optional[str] = None) -> str:
        """
        Build the unified model filename.

        Naming scheme (v2):
          - product mode: {model_type}_product_{product_id}_{version}.pth
          - store mode:   {model_type}_store_{store_id}_{version}.pth
          - global mode:  {model_type}_global_{aggregation_method}_{version}.pth

        Raises ValueError if the ID required by training_mode is missing.
        """
        if training_mode == 'store' and store_id:
            return f"{model_type}_store_{store_id}_{version}.pth"
        elif training_mode == 'global' and aggregation_method:
            return f"{model_type}_global_{aggregation_method}_{version}.pth"
        elif training_mode == 'product' and product_id:
            return f"{model_type}_product_{product_id}_{version}.pth"
        else:
            # Refuse rather than produce an invalid filename.
            raise ValueError(
                f"Cannot build a valid filename for training mode '{training_mode}': "
                f"the required ID is missing."
            )

    def save_model(self, model_data: dict, product_id: str, model_type: str,
                   store_id: Optional[str] = None, training_mode: str = 'product',
                   aggregation_method: Optional[str] = None,
                   product_name: Optional[str] = None,
                   version: Optional[str] = None) -> Tuple[str, str]:
        """
        Save a model to the unified location, managing versions automatically.

        Args:
            ...
            version: (optional) if given, overrides automatic versioning (e.g. 'best').

        Returns:
            (model file path, version string used)
        """
        if version is None:
            # No version given: allocate the next auto-incremented version.
            next_version_num = self._get_next_version(
                model_type=model_type,
                product_id=product_id,
                store_id=store_id,
                training_mode=training_mode,
                aggregation_method=aggregation_method,
            )
            version_str = f"v{next_version_num}"
        else:
            # Use the caller's version string as-is (e.g. 'v3', 'v3_best').
            version_str = version

        filename = self.generate_model_filename(
            model_type=model_type,
            version=version_str,
            training_mode=training_mode,
            product_id=product_id,
            store_id=store_id,
            aggregation_method=aggregation_method,
        )

        # Save everything flat in model_dir to avoid a nested sub-directory structure.
        model_path = os.path.join(self.model_dir, filename)

        # Attach management metadata to a (shallow) copy of the model data.
        enhanced_model_data = model_data.copy()
        enhanced_model_data.update({
            'model_manager_info': {
                'product_id': product_id,
                'product_name': product_name or product_id,
                'model_type': model_type,
                'version': version_str,
                'store_id': store_id,
                'training_mode': training_mode,
                'aggregation_method': aggregation_method,
                'created_at': datetime.now().isoformat(),
                'filename': filename,
            }
        })

        torch.save(enhanced_model_data, model_path)
        print(f"Model saved: {model_path}")
        return model_path, version_str

    def list_models(self, product_id: Optional[str] = None, model_type: Optional[str] = None,
                    store_id: Optional[str] = None, training_mode: Optional[str] = None,
                    page: Optional[int] = None, page_size: Optional[int] = None) -> Dict:
        """
        List all model files.

        Args:
            product_id: filter by product ID (optional)
            model_type: filter by model type (optional)
            store_id: filter by store ID (optional)
            training_mode: filter by training mode (optional)
            page: page number, starting at 1 (optional)
            page_size: number of items per page (optional)

        Returns:
            A dict with the model list under 'models' and pagination info under 'pagination'.
        """
        models = []

        # Scan every .pth file in the model directory
        pattern = os.path.join(self.model_dir, "*.pth")
        model_files = glob.glob(pattern)

        for model_file in model_files:
            try:
                # Parse the filename
                filename = os.path.basename(model_file)
                model_info = self.parse_model_filename(filename)
                if not model_info:
                    continue

                # Try to read extra metadata stored inside the model file
                try:
                    # weights_only=False first for backward compatibility (it unpickles
                    # arbitrary objects, so only load trusted files); fall back to
                    # weights_only=True if that fails.
                    try:
                        model_data = torch.load(model_file, map_location='cpu', weights_only=False)
                    except Exception:
                        model_data = torch.load(model_file, map_location='cpu', weights_only=True)

                    if isinstance(model_data, dict):
                        if 'model_manager_info' in model_data:
                            # Prefer the management metadata written by save_model
                            model_info.update(model_data['model_manager_info'])
                        # Evaluation metrics
                        if 'metrics' in model_data:
                            model_info['metrics'] = model_data['metrics']
                        # Training configuration
                        if 'config' in model_data:
                            model_info['config'] = model_data['config']
                except Exception as e:
                    # Continue with the filename-derived info only
                    print(f"Failed to read model file {model_file}: {e}")

                # Apply filters
                if product_id and model_info.get('product_id') != product_id:
                    continue
                if model_type and model_info.get('model_type') != model_type:
                    continue
                if store_id and model_info.get('store_id') != store_id:
                    continue
                if training_mode and model_info.get('training_mode') != training_mode:
                    continue

                # File information
                model_info['filename'] = filename
                # Path joined onto model_dir (relative whenever model_dir is relative)
                model_info['file_path'] = os.path.join(self.model_dir, filename)
                model_info['file_size'] = os.path.getsize(model_file)
                model_info['modified_at'] = datetime.fromtimestamp(
                    os.path.getmtime(model_file)
                ).isoformat()

                models.append(model_info)
            except Exception as e:
                print(f"Failed to process model file {model_file}: {e}")

        # Newest first, by creation time (falling back to file modification time)
        models.sort(key=lambda x: x.get('created_at') or x.get('modified_at', ''), reverse=True)

        total_count = len(models)

        # Without pagination parameters, return everything as a single page
        if page is None or page_size is None:
            return {
                'models': models,
                'pagination': {
                    'total': total_count,
                    'page': 1,
                    'page_size': total_count,
                    'total_pages': 1,
                    'has_next': False,
                    'has_previous': False,
                },
            }

        # Apply pagination (a non-positive page_size or page would produce bogus slices)
        if page_size <= 0:
            raise ValueError("page_size must be a positive integer")
        page = max(page, 1)
        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
        start_index = (page - 1) * page_size
        end_index = start_index + page_size
        paginated_models = models[start_index:end_index]

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page,
                'page_size': page_size,
                'total_pages': total_pages,
                'has_next': page < total_pages,
                'has_previous': page > 1,
            },
        }

    def parse_model_filename(self, filename: str) -> Optional[Dict]:
        """
        Parse a model filename and extract model info (v2 scheme).

        Supported formats:
          - product: {model_type}_product_{product_id}_{version}.pth
          - store:   {model_type}_store_{store_id}_{version}.pth
          - global:  {model_type}_global_{aggregation_method}_{version}.pth

        The version must look like 'v3', 'v3_best' or 'best'. Returns None for
        names that do not follow the scheme.
        """
        match = _MODEL_FILENAME_RE.fullmatch(filename)
        if not match:
            return None
        mode = match.group('training_mode')
        return {
            'model_type': match.group('model_type'),
            'training_mode': mode,
            _IDENTIFIER_KEYS[mode]: match.group('identifier'),
            'version': match.group('version'),
        }

    def delete_model(self, model_file: str) -> bool:
        """Delete a model file."""
        try:
            if os.path.exists(model_file):
                os.remove(model_file)
                print(f"Deleted model file: {model_file}")
                return True
            else:
                print(f"Model file does not exist: {model_file}")
                return False
        except Exception as e:
            print(f"Failed to delete model file: {e}")
            return False

    def get_model_by_id(self, model_id: str) -> Optional[Dict]:
        """Look up a model by its ID (the filename without the '.pth' extension)."""
        # list_models() returns {'models': [...], 'pagination': {...}}, not a bare list
        for model in self.list_models()['models']:
            if os.path.splitext(model.get('filename', ''))[0] == model_id:
                return model
        return None


# Global model manager instance.
# Uses DEFAULT_MODEL_DIR, the relative path defined in core/config.py.
model_manager = ModelManager()
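# ---------------------------------------------------------------------------
# Illustrative sketch, not used by ModelManager: the v2 naming scheme and the
# "next version" scan reduced to pure functions over a list of filenames, so the
# versioning rule can be tried without torch or a model directory. The helper
# names _sketch_filename and _sketch_next_version are hypothetical and exist
# only for this example.
# ---------------------------------------------------------------------------
import re
from typing import Iterable


def _sketch_filename(model_type: str, mode: str, identifier: str, version: str) -> str:
    """Hypothetical helper: build '{model_type}_{mode}_{identifier}_{version}.pth'."""
    return f"{model_type}_{mode}_{identifier}_{version}.pth"


def _sketch_next_version(names: Iterable[str], model_type: str, mode: str,
                         identifier: str) -> int:
    """Hypothetical helper: 1 + the highest N among '<prefix>vN.pth' / '<prefix>vN_best.pth'."""
    # Prefix shared by all versions of one model, e.g. 'mlstm_product_P1_'
    prefix = _sketch_filename(model_type, mode, identifier, "")[:-len(".pth")]
    pattern = re.compile(re.escape(prefix) + r"v(\d+)(?:_best)?\.pth")
    # fullmatch rejects names that merely start with the prefix, e.g. a file
    # belonging to product 'P1_v9' when scanning versions of product 'P1'.
    versions = [int(m.group(1)) for m in map(pattern.fullmatch, names) if m]
    return max(versions, default=0) + 1


# Example: given ['mlstm_product_P1_v1.pth', 'mlstm_product_P1_v2_best.pth',
# 'mlstm_product_P1_v9_v4.pth'], the next version for product 'P1' is 3 (the
# third file belongs to product 'P1_v9'), and the next version for 'P1_v9' is 5.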