"""
Unified model management utilities.

Handles uniform naming, storage and retrieval of model files,
following the hierarchical directory layout and file-versioning rules.
"""
import glob
import json
import os
from datetime import datetime
from threading import Lock
from typing import Any, Dict, List, Optional

import torch

from core.config import DEFAULT_MODEL_DIR
class ModelManager:
    """
    Unified model manager using a structured directory layout plus a JSON
    version file.

    Directory layout (the aggregation-method level exists only for
    ``training_mode == 'global'``)::

        <model_dir>/<training_mode>/<scope>[/<aggregation_method>]/<model_type>/v<N>/
    """

    # Name of the JSON file (stored at the model root) that maps
    # model identifier -> latest integer version.
    VERSION_FILE = 'versions.json'

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        """
        Args:
            model_dir: Root directory for all managed models; created if absent.
        """
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        # Serializes version-file access within this process only; concurrent
        # writers in other processes are NOT protected.
        self._lock = Lock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Ensure the model root directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _read_versions(self) -> Dict[str, int]:
        """Thread-safely read the version file; return {} if missing or unreadable."""
        with self._lock:
            if not os.path.exists(self.versions_path):
                return {}
            try:
                with open(self.versions_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                # Corrupt/unreadable version file: treat as empty rather than crash.
                return {}

    def _write_versions(self, versions: Dict[str, int]):
        """Thread-safely overwrite the version file with ``versions``."""
        with self._lock:
            with open(self.versions_path, 'w', encoding='utf-8') as f:
                json.dump(versions, f, indent=2)

    def get_model_identifier(self,
                             model_type: str,
                             training_mode: str,
                             scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """
        Build the unique identifier used as the key in the version file.

        Global models embed the aggregation method in the identifier; other
        training modes do not.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

    def get_next_version_number(self, model_identifier: str) -> int:
        """
        Return the next integer version number for the given model.

        NOTE(review): the read here and a later ``update_version`` call are not
        atomic as a pair — concurrent callers can obtain the same number.
        Callers should serialize version bumps.
        """
        versions = self._read_versions()
        return versions.get(model_identifier, 0) + 1

    def update_version(self, model_identifier: str, new_version: int):
        """Record ``new_version`` as the latest version of the model."""
        versions = self._read_versions()
        versions[model_identifier] = new_version
        self._write_versions(versions)

    def get_model_version_path(self,
                               model_type: str,
                               training_mode: str,
                               scope: str,
                               version: int,
                               aggregation_method: Optional[str] = None) -> str:
        """
        Build the full directory path of one model version under the layout rules.

        Returns:
            ``<model_dir>/<training_mode>/<scope>[/<agg_method>]/<model_type>/v<version>``
        """
        base_path = os.path.join(self.model_dir, training_mode, scope)
        # Only global training adds the aggregation-method directory level.
        if training_mode == 'global' and aggregation_method:
            base_path = os.path.join(base_path, str(aggregation_method))

        return os.path.join(base_path, model_type, f'v{version}')

    def save_model_artifact(self,
                            artifact_data: Any,
                            artifact_name: str,
                            model_version_path: str):
        """
        Save one artifact inside a model version directory.

        The handler is chosen by file extension:
          - ``.pth``  -> ``torch.save``
          - ``.png``  -> ``artifact_data.savefig`` (matplotlib-figure-like object)
          - ``.json`` -> ``json.dump`` (UTF-8, non-ASCII preserved)

        Args:
            artifact_data: The data to save (e.g. state dict, figure, dict).
            artifact_name: Standardized file name (e.g. 'model.pth', 'loss_curve.png').
            model_version_path: Target model version directory (created if absent).

        Raises:
            ValueError: If the artifact name matches no supported handler.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)

        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"不支持的产物类型: {artifact_name}")

        print(f"产物已保存: {full_path}")

    def _collect_version_dirs(self, type_parent_path: str) -> List[Dict]:
        """
        Scan ``type_parent_path`` (a scope or aggregation-method directory) for
        ``<model_type>/v<N>`` folders and parse each into a model-info dict.
        Unparseable entries are skipped.
        """
        found = []
        for model_type in os.listdir(type_parent_path):
            type_path = os.path.join(type_parent_path, model_type)
            if not os.path.isdir(type_path):
                continue
            for version_folder in os.listdir(type_path):
                # Version directories follow the 'v<N>' naming convention.
                if version_folder.startswith('v'):
                    version_path = os.path.join(type_path, version_folder)
                    model_info = self._parse_info_from_path(version_path)
                    if model_info:
                        found.append(model_info)
        return found

    def list_models(self,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List all models by scanning the directory structure.

        Args:
            page: 1-based page number; pagination applies only when both
                ``page`` and ``page_size`` are truthy.
            page_size: Number of models per page.

        Returns:
            Dict with 'models' (possibly paginated list) and 'pagination' info.
        """
        all_models = []
        for training_mode in os.listdir(self.model_dir):
            mode_path = os.path.join(self.model_dir, training_mode)
            # Skip plain files and reserved directories at the root level.
            if not os.path.isdir(mode_path) or training_mode in ('checkpoints', self.VERSION_FILE):
                continue

            for scope in os.listdir(mode_path):
                scope_path = os.path.join(mode_path, scope)
                if not os.path.isdir(scope_path):
                    continue

                # Capture the listing once so the emptiness check and the
                # first-entry probe cannot race against concurrent deletions.
                entries = os.listdir(scope_path)

                # Global mode may insert an aggregation-method directory level;
                # detect it by probing whether the first entry is a directory.
                is_global_agg_level = (
                    training_mode == 'global'
                    and bool(entries)
                    and os.path.isdir(os.path.join(scope_path, entries[0]))
                )

                if is_global_agg_level:
                    for agg_method in entries:
                        agg_path = os.path.join(scope_path, agg_method)
                        if os.path.isdir(agg_path):
                            all_models.extend(self._collect_version_dirs(agg_path))
                else:
                    all_models.extend(self._collect_version_dirs(scope_path))

        total_count = len(all_models)
        if page and page_size:
            start_index = (page - 1) * page_size
            paginated_models = all_models[start_index:start_index + page_size]
        else:
            paginated_models = all_models

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """
        Parse model information from a version directory path.

        Returns:
            Info dict (path components merged with optional metadata.json
            contents), or None if the path cannot be parsed.
        """
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)

            relative_path = os.path.relpath(norm_path, norm_model_dir)
            parts = relative_path.split(os.sep)

            # Layout: <training_mode>/<scope>[/<agg_method>]/<model_type>/v<N>
            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2],
            }

            training_mode = parts[0]
            info['training_mode'] = training_mode
            info['scope'] = parts[1]

            if training_mode == 'global':
                info['aggregation_method'] = parts[2]
                info['model_identifier'] = self.get_model_identifier(
                    info['model_type'], training_mode, info['scope'], info['aggregation_method'])
            else:
                info['aggregation_method'] = None
                info['model_identifier'] = self.get_model_identifier(
                    info['model_type'], training_mode, info['scope'])

            # Merge optional metadata stored alongside the artifacts; metadata
            # keys deliberately override the path-derived fields.
            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                info.update(metadata)

            return info
        except (IndexError, IOError, json.JSONDecodeError) as e:
            # JSONDecodeError added: a corrupt metadata.json previously
            # escaped this handler and aborted the whole list_models scan.
            print(f"解析路径失败 {version_path}: {e}")
            return None
# Shared, module-level model manager instance used across the application.
model_manager = ModelManager()