"""
Unified model management utilities.

Handles consistent naming, storage and lookup of model files using a
hierarchical directory layout plus a JSON version file.

Directory layouts produced by ``ModelManager.get_model_version_path``::

    <model_dir>/<training_mode>/<scope>/<model_type>/v<N>
    <model_dir>/global/<scope>/<aggregation_method>/<model_type>/v<N>

The second form is used only when ``training_mode == 'global'`` and an
``aggregation_method`` is supplied.
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Any, Iterator
from threading import Lock, RLock

from core.config import DEFAULT_MODEL_DIR


class ModelManager:
    """
    Unified model manager backed by a structured directory tree and a
    version file (``versions.json``) in the model root directory.

    The internal lock serialises version-file access between threads of a
    single process only; it does not coordinate separate processes.
    """

    VERSION_FILE = 'versions.json'

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        # Re-entrant so update_version() can hold the lock across its whole
        # read-modify-write while still calling the (locking) helpers below.
        self._lock = RLock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Create the model root directory if it does not already exist."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _read_versions(self) -> Dict[str, int]:
        """
        Load the identifier -> latest-version map from the version file.

        Returns:
            The stored mapping, or an empty dict when the file is missing,
            unreadable, not valid JSON, or does not contain a JSON object.
        """
        with self._lock:
            if not os.path.exists(self.versions_path):
                return {}
            try:
                with open(self.versions_path, 'r', encoding='utf-8') as f:
                    versions = json.load(f)
            except (json.JSONDecodeError, OSError):
                return {}
            return versions if isinstance(versions, dict) else {}

    def _write_versions(self, versions: Dict[str, int]):
        """
        Persist the identifier -> latest-version map.

        The data is written to a temporary file next to the version file and
        then moved into place with ``os.replace``. A crash or a concurrent
        reader therefore never sees a truncated ``versions.json``. A truncated
        file would be read back as ``{}``, resetting every version counter.
        """
        with self._lock:
            tmp_path = f"{self.versions_path}.{os.getpid()}.tmp"
            try:
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(versions, f, indent=2)
                os.replace(tmp_path, self.versions_path)
            except BaseException:
                # Best-effort cleanup of the temp file, then surface the error.
                try:
                    os.remove(tmp_path)
                except OSError:
                    pass
                raise

    def get_model_identifier(self, model_type: str, training_mode: str, scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """
        Build the unique model identifier used as the key in the version file.

        Global models include the aggregation method in the identifier (as the
        literal text ``None`` when it is not given); other modes do not.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

    def get_next_version_number(self, model_identifier: str) -> int:
        """Return the next (integer) version number for ``model_identifier``; 1 if unseen."""
        versions = self._read_versions()
        current_version = versions.get(model_identifier, 0)
        return current_version + 1

    def update_version(self, model_identifier: str, new_version: int):
        """
        Record ``new_version`` as the latest version of ``model_identifier``.

        The read, modify and write happen under a single lock acquisition.
        Concurrent updates from other threads of this process therefore cannot
        overwrite each other's entries.
        """
        with self._lock:
            versions = self._read_versions()
            versions[model_identifier] = new_version
            self._write_versions(versions)

    def get_model_version_path(self, model_type: str, training_mode: str, scope: str, version: int,
                               aggregation_method: Optional[str] = None) -> str:
        """
        Return the full path of a model version directory (not created here).

        Layout: ``<model_dir>/<training_mode>/<scope>[/<aggregation_method>]/<model_type>/v<version>``.
        The aggregation level is included only for global models with a truthy
        ``aggregation_method``.
        """
        base_path = os.path.join(self.model_dir, training_mode, scope)
        if training_mode == 'global' and aggregation_method:
            base_path = os.path.join(base_path, str(aggregation_method))
        version_path = os.path.join(base_path, model_type, f'v{version}')
        return version_path

    def save_model_artifact(self, artifact_data: Any, artifact_name: str, model_version_path: str):
        """
        Save one artifact inside a model version directory, creating it if needed.

        Args:
            artifact_data: Data to save. A ``.pth`` artifact is passed to
                ``torch.save``. A ``.png`` artifact must expose ``savefig``
                (e.g. a matplotlib figure). A ``.json`` artifact must be
                JSON-serialisable.
            artifact_name: Standardised file name, e.g. ``'model.pth'`` or
                ``'loss_curve.png'``; the extension selects the writer.
            model_version_path: Path of the model version directory.

        Raises:
            ValueError: If the extension is unsupported, or a ``.png`` artifact
                has no ``savefig`` method.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)

        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"不支持的产物类型: {artifact_name}")

        print(f"产物已保存: {full_path}")

    @staticmethod
    def _sorted_subdirs(path: str) -> List[str]:
        """Return the names of the immediate sub-directories of ``path``, sorted by name."""
        return sorted(n for n in os.listdir(path) if os.path.isdir(os.path.join(path, n)))

    @staticmethod
    def _is_version_dir(parent: str, name: str) -> bool:
        """Return True if ``parent/name`` is a directory named like ``v<digit>...``."""
        # Version directories are created as f'v{version}', so requiring a
        # digit after 'v' avoids picking up files or dirs such as 'vocab.json'.
        return name.startswith('v') and name[1:2].isdigit() and os.path.isdir(os.path.join(parent, name))

    @staticmethod
    def _version_sort_key(name: str):
        """Order 'v2' before 'v10'; names with non-numeric suffixes sort after, by name."""
        suffix = name[1:]
        return (0, int(suffix), name) if suffix.isdigit() else (1, 0, name)

    def _version_dirs(self, type_path: str) -> List[str]:
        """Return the full paths of the version directories inside ``type_path``, in version order."""
        names = [n for n in os.listdir(type_path) if self._is_version_dir(type_path, n)]
        return [os.path.join(type_path, n) for n in sorted(names, key=self._version_sort_key)]

    def _iter_type_dirs(self, training_mode: str, scope_path: str) -> Iterator[str]:
        """
        Yield every model-type directory under one scope directory.

        Non-global scopes contain model-type directories directly. A global
        scope may mix both layouts, so each child is classified on its own. A
        child that directly holds version directories is a model-type
        directory. Any other child is treated as an aggregation-method
        directory whose children are model-type directories.
        """
        for child in self._sorted_subdirs(scope_path):
            child_path = os.path.join(scope_path, child)
            if training_mode == 'global' and not self._version_dirs(child_path):
                for model_type in self._sorted_subdirs(child_path):
                    yield os.path.join(child_path, model_type)
            else:
                yield child_path

    def _iter_version_paths(self) -> Iterator[str]:
        """Yield every model version directory under the model root, in a deterministic order."""
        for training_mode in self._sorted_subdirs(self.model_dir):
            if training_mode == 'checkpoints':
                continue
            mode_path = os.path.join(self.model_dir, training_mode)
            for scope in self._sorted_subdirs(mode_path):
                scope_path = os.path.join(mode_path, scope)
                for type_path in self._iter_type_dirs(training_mode, scope_path):
                    yield from self._version_dirs(type_path)

    def list_models(self, page: Optional[int] = None, page_size: Optional[int] = None) -> Dict:
        """
        List all models by scanning the directory tree.

        Results are sorted by directory name, with versions in numeric order,
        so paging is stable between calls.

        Args:
            page: 1-based page number. Pagination applies only when both
                ``page`` and ``page_size`` are truthy.
            page_size: Number of models per page.

        Returns:
            ``{'models': [...], 'pagination': {'total', 'page', 'page_size',
            'total_pages'}}``.
        """
        all_models = []
        for version_path in self._iter_version_paths():
            model_info = self._parse_info_from_path(version_path)
            if model_info:
                all_models.append(model_info)

        total_count = len(all_models)
        if page and page_size:
            start_index = (page - 1) * page_size
            end_index = start_index + page_size
            paginated_models = all_models[start_index:end_index]
        else:
            paginated_models = all_models

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """
        Build a model-info dict from a version directory path.

        Fields from ``metadata.json`` in the version directory, when present and
        a JSON object, are merged on top of the path-derived fields.

        Returns:
            The info dict, or None (after printing a message) if the path cannot
            be parsed or its metadata cannot be read.
        """
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)
            parts = os.path.relpath(norm_path, norm_model_dir).split(os.sep)

            training_mode = parts[0]
            scope = parts[1]
            # Only global/<scope>/<agg>/<type>/<version> has five components.
            # A global model saved without an aggregation method has four, and
            # parts[2] is then the model type.
            aggregation_method = parts[2] if training_mode == 'global' and len(parts) >= 5 else None

            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2],
                'training_mode': training_mode,
                'scope': scope,
                'aggregation_method': aggregation_method,
            }
            info['model_identifier'] = self.get_model_identifier(
                info['model_type'], training_mode, scope, aggregation_method)

            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                if isinstance(metadata, dict):
                    info.update(metadata)
            return info
        except (IndexError, OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError from a corrupt metadata.json.
            print(f"解析路径失败 {version_path}: {e}")
            return None


# Global model manager instance (creates DEFAULT_MODEL_DIR at import time).
model_manager = ModelManager()