""" 统一模型管理工具 处理模型文件的统一命名、存储和检索 遵循层级式目录结构和文件版本管理规则 """ import os import json import torch import glob from datetime import datetime from typing import List, Dict, Optional, Any from threading import Lock from core.config import DEFAULT_MODEL_DIR class ModelManager: """ 统一模型管理器,采用结构化目录和版本文件进行管理。 """ VERSION_FILE = 'versions.json' def __init__(self, model_dir: str = DEFAULT_MODEL_DIR): self.model_dir = os.path.abspath(model_dir) self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE) self._lock = Lock() self.ensure_model_dir() def ensure_model_dir(self): """确保模型根目录存在""" os.makedirs(self.model_dir, exist_ok=True) def _read_versions(self) -> Dict[str, int]: """线程安全地读取版本文件""" with self._lock: if not os.path.exists(self.versions_path): return {} try: with open(self.versions_path, 'r', encoding='utf-8') as f: return json.load(f) except (json.JSONDecodeError, IOError): return {} def _write_versions(self, versions: Dict[str, int]): """线程安全地写入版本文件""" with self._lock: with open(self.versions_path, 'w', encoding='utf-8') as f: json.dump(versions, f, indent=2) def get_model_identifier(self, model_type: str, training_mode: str, scope: str, aggregation_method: Optional[str] = None) -> str: """ 生成模型的唯一标识符,用于版本文件中的key。 """ if training_mode == 'global': return f"{training_mode}_{scope}_{aggregation_method}_{model_type}" return f"{training_mode}_{scope}_{model_type}" def get_next_version_number(self, model_identifier: str) -> int: """ 获取指定模型的下一个版本号(整数)。 """ versions = self._read_versions() current_version = versions.get(model_identifier, 0) return current_version + 1 def update_version(self, model_identifier: str, new_version: int): """ 更新模型的最新版本号。 """ versions = self._read_versions() versions[model_identifier] = new_version self._write_versions(versions) def get_model_version_path(self, model_type: str, version: int, training_mode: str, aggregation_method: Optional[str] = None, store_id: Optional[str] = None, product_id: Optional[str] = None, scope: Optional[str] = None) -> str: # scope为了兼容旧调用 """ 根据 `xz训练模型保存规则.md` 中定义的新规则生成模型版本目录的完整路径。 """ # 基础路径始终是 self.model_dir base_path = self.model_dir # 确定第一级目录,根据规则,所有模式都在 'global' 下 path_parts = [base_path, 'global'] if training_mode == 'global': # global/all/{aggregation_method}/{model_type}/v{N}/ path_parts.extend(['all', str(aggregation_method)]) elif training_mode == 'stores': # global/stores/{store_id}/{aggregation_method}/{model_type}/v{N}/ if not store_id: raise ValueError("store_id is required for 'stores' training mode.") path_parts.extend(['stores', store_id, str(aggregation_method)]) elif training_mode == 'products': # global/products/{product_id}/{aggregation_method}/{model_type}/v{N}/ if not product_id: raise ValueError("product_id is required for 'products' training mode.") path_parts.extend(['products', product_id, str(aggregation_method)]) elif training_mode == 'custom': # global/custom/{store_id}/{product_id}/{aggregation_method}/{model_type}/v{N}/ if not store_id or not product_id: raise ValueError("store_id and product_id are required for 'custom' training mode.") path_parts.extend(['custom', store_id, product_id, str(aggregation_method)]) else: raise ValueError(f"不支持的 training_mode: {training_mode}") path_parts.extend([model_type, f'v{version}']) return os.path.join(*path_parts) def save_model_artifact(self, artifact_data: Any, artifact_name: str, model_version_path: str): """ 在指定的模型版本目录下保存一个产物。 Args: artifact_data: 要保存的数据 (e.g., model state dict, figure object). artifact_name: 标准化的产物文件名 (e.g., 'model.pth', 'loss_curve.png'). model_version_path: 模型版本目录的路径. """ os.makedirs(model_version_path, exist_ok=True) full_path = os.path.join(model_version_path, artifact_name) if artifact_name.endswith('.pth'): torch.save(artifact_data, full_path) elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'): artifact_data.savefig(full_path, dpi=300, bbox_inches='tight') elif artifact_name.endswith('.json'): with open(full_path, 'w', encoding='utf-8') as f: json.dump(artifact_data, f, indent=2, ensure_ascii=False) else: raise ValueError(f"不支持的产物类型: {artifact_name}") print(f"产物已保存: {full_path}") def list_models(self, page: Optional[int] = None, page_size: Optional[int] = None) -> Dict: """ 通过扫描目录结构来列出所有模型 (适配新结构)。 """ all_models = [] # 使用glob查找所有版本目录 search_pattern = os.path.join(self.model_dir, '**', 'v*') for version_path in glob.glob(search_pattern, recursive=True): # 确保它是一个目录并且包含 metadata.json metadata_path = os.path.join(version_path, 'metadata.json') if os.path.isdir(version_path) and os.path.exists(metadata_path): model_info = self._parse_info_from_path(version_path) if model_info: all_models.append(model_info) # 按时间戳降序排序 all_models.sort(key=lambda x: x.get('timestamp', ''), reverse=True) total_count = len(all_models) if page and page_size: start_index = (page - 1) * page_size end_index = start_index + page_size paginated_models = all_models[start_index:end_index] else: paginated_models = all_models return { 'models': paginated_models, 'pagination': { 'total': total_count, 'page': page or 1, 'page_size': page_size or total_count, 'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1, } } def _parse_info_from_path(self, version_path: str) -> Optional[Dict]: """根据新的目录结构从版本目录路径解析模型信息""" try: norm_path = os.path.normpath(version_path) norm_model_dir = os.path.normpath(self.model_dir) relative_path = os.path.relpath(norm_path, norm_model_dir) parts = relative_path.split(os.sep) # 期望路径: global/{scope_type}/{id...}/{agg_method}/{model_type}/v{N} if parts[0] != 'global' or len(parts) < 5: return None # 不是规范的新路径 info = { 'model_path': version_path, 'version': parts[-1], 'model_type': parts[-2], 'store_id': None, 'product_id': None, } scope_type = parts[1] # all, stores, products, custom if scope_type == 'all': # global/all/sum/mlstm/v1 info['training_mode'] = 'global' info['aggregation_method'] = parts[2] elif scope_type == 'stores': # global/stores/S001/sum/mlstm/v1 info['training_mode'] = 'stores' info['store_id'] = parts[2] info['aggregation_method'] = parts[3] elif scope_type == 'products': # global/products/P001/sum/mlstm/v1 info['training_mode'] = 'products' info['product_id'] = parts[2] info['aggregation_method'] = parts[3] elif scope_type == 'custom': # global/custom/S001/P001/sum/mlstm/v1 info['training_mode'] = 'custom' info['store_id'] = parts[2] info['product_id'] = parts[3] info['aggregation_method'] = parts[4] else: return None metadata_path = os.path.join(version_path, 'metadata.json') if os.path.exists(metadata_path): with open(metadata_path, 'r', encoding='utf-8') as f: metadata = json.load(f) # 确保从路径解析出的ID覆盖元数据中的,因为路径是权威来源 info.update(metadata) info['version'] = parts[-1] # 重新覆盖,确保正确 info['model_type'] = parts[-2] return info except (IndexError, IOError) as e: print(f"解析路径失败 {version_path}: {e}") return None # 全局模型管理器实例 model_manager = ModelManager()