"""
Unified model management utilities.

Handle consistent naming, storage, and retrieval of model files,
following the hierarchical directory structure and file versioning rules.
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Any
from threading import Lock

from core.config import DEFAULT_MODEL_DIR


class ModelManager:
    """
    Unified model manager that organizes models in a structured directory
    tree and tracks the latest versions in a version file.
    """
    VERSION_FILE = 'versions.json'

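    # The version file is a flat JSON mapping from a model identifier to its
    # latest integer version. Illustrative content (identifiers are hypothetical):
    #   {
    #     "store_S001_mlstm": 3,
    #     "global_all_sum_transformer": 1
    #   }
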
    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        self._lock = Lock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Ensure that the model root directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    def _read_versions(self) -> Dict[str, int]:
        """Read the version file in a thread-safe way."""
        with self._lock:
            if not os.path.exists(self.versions_path):
                return {}
            try:
                with open(self.versions_path, 'r', encoding='utf-8') as f:
                    return json.load(f)
            except (json.JSONDecodeError, IOError):
                return {}

    def _write_versions(self, versions: Dict[str, int]):
        """Write the version file in a thread-safe way."""
        with self._lock:
            with open(self.versions_path, 'w', encoding='utf-8') as f:
                json.dump(versions, f, indent=2)

    def get_model_identifier(self,
                             model_type: str,
                             training_mode: str,
                             scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """
        Generate the unique identifier for a model, used as the key in the version file.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

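    # For example (with hypothetical scope values):
    #   get_model_identifier('mlstm', 'store', 'S001')         -> 'store_S001_mlstm'
    #   get_model_identifier('mlstm', 'global', 'all', 'sum')  -> 'global_all_sum_mlstm'
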
    def get_next_version_number(self, model_identifier: str) -> int:
        """
        Return the next (integer) version number for the given model.
        """
        versions = self._read_versions()
        current_version = versions.get(model_identifier, 0)
        return current_version + 1

    def update_version(self, model_identifier: str, new_version: int):
        """
        Record the latest version number for a model.
        """
        versions = self._read_versions()
        versions[model_identifier] = new_version
        self._write_versions(versions)

    def get_model_version_path(self,
                               model_type: str,
                               version: int,
                               training_mode: str,
                               scope: str,
                               aggregation_method: Optional[str] = None) -> str:
        """
        Build the full path of a model version directory according to the new
        rules defined in `xz训练模型保存规则.md`.
        """
        path_parts = [self.model_dir]

        if training_mode == 'product':
            # saved_models/product/{scope}/{model_type}/v{N}/
            if not scope:
                raise ValueError("scope is required for 'product' training mode.")
            path_parts.extend(['product', scope, model_type, f'v{version}'])

        elif training_mode == 'store':
            # saved_models/store/{scope}/{model_type}/v{N}/
            if not scope:
                raise ValueError("scope is required for 'store' training mode.")
            path_parts.extend(['store', scope, model_type, f'v{version}'])

        elif training_mode == 'global':
            # saved_models/global/{scope_path}/{aggregation_method}/{model_type}/v{N}/
            if not scope:
                raise ValueError("scope is required for 'global' training mode.")
            if not aggregation_method:
                raise ValueError("aggregation_method is required for 'global' training mode.")

            scope_parts = scope.split('/')
            path_parts.extend(['global', *scope_parts, str(aggregation_method), model_type, f'v{version}'])

        else:
            raise ValueError(f"Unsupported training_mode: {training_mode}")

        return os.path.join(*path_parts)

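    # For example (with hypothetical scope values), relative to model_dir:
    #   get_model_version_path('mlstm', 2, 'store', 'S001')              -> store/S001/mlstm/v2
    #   get_model_version_path('transformer', 1, 'global', 'all', 'sum') -> global/all/sum/transformer/v1
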
    def save_model_artifact(self,
                            artifact_data: Any,
                            artifact_name: str,
                            model_version_path: str):
        """
        Save a single artifact into the given model version directory.

        Args:
            artifact_data: The data to save (e.g. a model state dict or a figure object).
            artifact_name: Standardized artifact file name (e.g. 'model.pth', 'loss_curve.png').
            model_version_path: Path of the model version directory.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)

        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"Unsupported artifact type: {artifact_name}")

        print(f"Artifact saved: {full_path}")

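    # Typical usage after training (sketch; `manager`, `model`, and `loss_fig` are
    # hypothetical objects supplied by the caller):
    #   path = manager.get_model_version_path('mlstm', 2, 'store', 'S001')
    #   manager.save_model_artifact(model.state_dict(), 'model.pth', path)
    #   manager.save_model_artifact(loss_fig, 'loss_curve.png', path)
    #   manager.save_model_artifact({'timestamp': '2025-07-16T12:00:00'}, 'metadata.json', path)
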
    def list_models(self,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """
        List all models by scanning the directory structure (adapted to the new layout).
        """
        all_models = []
        # Use glob to find all version directories.
        search_pattern = os.path.join(self.model_dir, '**', 'v*')

        for version_path in glob.glob(search_pattern, recursive=True):
            # Only keep entries that are directories and contain a metadata.json.
            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.isdir(version_path) and os.path.exists(metadata_path):
                model_info = self._parse_info_from_path(version_path)
                if model_info:
                    all_models.append(model_info)

        # Sort by timestamp, newest first.
        all_models.sort(key=lambda x: x.get('timestamp', ''), reverse=True)

        total_count = len(all_models)
        if page and page_size:
            start_index = (page - 1) * page_size
            end_index = start_index + page_size
            paginated_models = all_models[start_index:end_index]
        else:
            paginated_models = all_models

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

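    # The returned structure looks roughly like this (values are illustrative):
    #   {
    #     'models': [{'model_path': '...', 'version': 'v2', 'model_type': 'mlstm',
    #                 'training_mode': 'store', 'scope': 'S001', ...}],
    #     'pagination': {'total': 1, 'page': 1, 'page_size': 1, 'total_pages': 1}
    #   }
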
    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """Parse model information from a version directory path under the new layout."""
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)

            relative_path = os.path.relpath(norm_path, norm_model_dir)
            parts = relative_path.split(os.sep)

            if len(parts) < 4:
                return None

            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2],
                'training_mode': parts[0],
                'store_id': None,
                'product_id': None,
                'aggregation_method': None,
                'scope': None
            }

            mode = parts[0]
            if mode == 'product':
                # product/{scope}/mlstm/v1
                info['scope'] = parts[1]
            elif mode == 'store':
                # store/{scope}/mlstm/v1
                info['scope'] = parts[1]
            elif mode == 'global':
                # global/{scope...}/sum/mlstm/v1
                info['aggregation_method'] = parts[-3]
                info['scope'] = '/'.join(parts[1:-3])
            else:
                return None  # Unknown training mode.

            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                info.update(metadata)
                # Key fields parsed from the path override the metadata,
                # because the path is the authoritative source.
                info['version'] = parts[-1]
                info['model_type'] = parts[-2]
                info['training_mode'] = parts[0]

            return info
        except (IndexError, IOError, json.JSONDecodeError) as e:
            print(f"Failed to parse path {version_path}: {e}")
            return None

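    # Example: a version directory at {model_dir}/global/all/sum/mlstm/v3 is parsed
    # into training_mode='global', scope='all', aggregation_method='sum',
    # model_type='mlstm', version='v3' ('all' is a hypothetical scope value).
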

# Global model manager instance
model_manager = ModelManager()
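

# Minimal end-to-end usage sketch; the store id 'S001' and model type 'mlstm'
# are placeholder values, and running this writes demo files under the
# configured model directory.
if __name__ == '__main__':
    manager = ModelManager()
    identifier = manager.get_model_identifier('mlstm', 'store', 'S001')
    next_version = manager.get_next_version_number(identifier)
    version_path = manager.get_model_version_path('mlstm', next_version, 'store', 'S001')
    manager.save_model_artifact({'timestamp': datetime.now().isoformat()}, 'metadata.json', version_path)
    manager.update_version(identifier, next_version)
    print(manager.list_models())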