**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**

* **内容**: 完全重写了 `kan_trainer.py`。
* **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
* **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
* **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
* **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**

* **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
* 移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
* 全面转向使用 `model_manager` 进行版本控制和文件保存。
* 统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**

* **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
* 移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
* 实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**

* **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
* **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
* **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
* **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
* **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
226 lines
9.3 KiB
Python
226 lines
9.3 KiB
Python
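"""
Unified model management utilities.

Handles the consistent naming, storage, and retrieval of model files,
following a hierarchical directory layout and file-based version tracking rules.
"""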
|
||
|
||
import os
|
||
import json
|
||
import torch
|
||
import glob
|
||
from datetime import datetime
|
||
from typing import List, Dict, Optional, Any
|
||
from threading import Lock
|
||
from core.config import DEFAULT_MODEL_DIR
|
||
|
||
class ModelManager:
    """Unified model manager built on a structured directory tree and a version file.

    Version directories produced by :meth:`get_model_version_path` look like::

        <model_dir>/<training_mode>/<scope>/[<aggregation_method>/]<model_type>/v<N>

    The ``<aggregation_method>`` level only exists when ``training_mode`` is
    ``'global'`` and a truthy aggregation method was supplied. The latest
    version number of every model identifier is stored in
    ``<model_dir>/versions.json``; all access to that file is serialised
    through ``self._lock``.
    """

    VERSION_FILE = 'versions.json'

    def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
        """Create a manager rooted at ``model_dir`` and make sure that directory exists.

        Args:
            model_dir: Root directory for all model artifacts. It is stored as
                an absolute path.
        """
        self.model_dir = os.path.abspath(model_dir)
        self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
        # Non-reentrant lock: public methods acquire it exactly once and then
        # call the ``*_unlocked`` helpers, which never acquire it themselves.
        self._lock = Lock()
        self.ensure_model_dir()

    def ensure_model_dir(self):
        """Ensure the model root directory exists."""
        os.makedirs(self.model_dir, exist_ok=True)

    # ------------------------------------------------------------------
    # Version file handling
    # ------------------------------------------------------------------

    def _load_versions_unlocked(self) -> Dict[str, int]:
        """Load the version mapping from disk. The caller must hold ``self._lock``.

        Returns:
            The mapping stored in the version file, or an empty dict when the
            file is missing, unreadable, not valid JSON, or not a JSON object.
        """
        if not os.path.exists(self.versions_path):
            return {}
        try:
            with open(self.versions_path, 'r', encoding='utf-8') as f:
                data = json.load(f)
        except (json.JSONDecodeError, IOError):
            return {}
        # Anything but a JSON object cannot be used as an identifier -> version map.
        return data if isinstance(data, dict) else {}

    def _store_versions_unlocked(self, versions: Dict[str, int]):
        """Persist the version mapping. The caller must hold ``self._lock``.

        The data is written to a sibling temporary file first and then moved
        over the real file with ``os.replace``, so a failed write never leaves
        a truncated ``versions.json`` behind (a truncated file would be read
        back as ``{}`` and silently reset every version counter).
        """
        tmp_path = self.versions_path + '.tmp'
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump(versions, f, indent=2)
        os.replace(tmp_path, self.versions_path)

    def _read_versions(self) -> Dict[str, int]:
        """Read the version file while holding the lock."""
        with self._lock:
            return self._load_versions_unlocked()

    def _write_versions(self, versions: Dict[str, int]):
        """Write the version file while holding the lock."""
        with self._lock:
            self._store_versions_unlocked(versions)

    def get_model_identifier(self,
                             model_type: str,
                             training_mode: str,
                             scope: str,
                             aggregation_method: Optional[str] = None) -> str:
        """Build the unique model identifier used as the key in the version file.

        For ``training_mode == 'global'`` the aggregation method is part of the
        key; a value of ``None`` is interpolated literally as ``'None'``. This
        format is kept unchanged so that keys already stored in existing
        version files keep matching.

        Args:
            model_type: Model type name.
            training_mode: Training mode name; ``'global'`` selects the
                four-part key format.
            scope: Scope component of the key.
            aggregation_method: Aggregation method, only used for ``'global'``.

        Returns:
            ``"<training_mode>_<scope>_<aggregation_method>_<model_type>"`` for
            global models, otherwise ``"<training_mode>_<scope>_<model_type>"``.
        """
        if training_mode == 'global':
            return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
        return f"{training_mode}_{scope}_{model_type}"

    def get_next_version_number(self, model_identifier: str) -> int:
        """Return the next integer version number for ``model_identifier``.

        The number is only computed, not reserved: the stored value changes
        when :meth:`update_version` is called.

        Args:
            model_identifier: Key as produced by :meth:`get_model_identifier`.

        Returns:
            The stored version plus one, or ``1`` when the identifier is unknown.
        """
        versions = self._read_versions()
        current_version = versions.get(model_identifier, 0)
        return current_version + 1

    def update_version(self, model_identifier: str, new_version: int):
        """Record ``new_version`` as the latest version of ``model_identifier``.

        The read-modify-write cycle runs under a single lock acquisition so
        that concurrent updates of different identifiers cannot overwrite each
        other's entries.

        Args:
            model_identifier: Key as produced by :meth:`get_model_identifier`.
            new_version: Version number to store.
        """
        with self._lock:
            versions = self._load_versions_unlocked()
            versions[model_identifier] = new_version
            self._store_versions_unlocked(versions)

    # ------------------------------------------------------------------
    # Paths and artifacts
    # ------------------------------------------------------------------

    def get_model_version_path(self,
                               model_type: str,
                               training_mode: str,
                               scope: str,
                               version: int,
                               aggregation_method: Optional[str] = None) -> str:
        """Build the full path of a model version directory.

        Args:
            model_type: Model type name (directory above the version folder).
            training_mode: First directory level below the model root.
            scope: Second directory level.
            version: Integer version; the folder is named ``v<version>``.
            aggregation_method: Extra directory level, inserted only when
                ``training_mode`` is ``'global'`` and this value is truthy.

        Returns:
            The version directory path. The directory is not created here.
        """
        base_path = os.path.join(self.model_dir, training_mode, scope)
        if training_mode == 'global' and aggregation_method:
            base_path = os.path.join(base_path, str(aggregation_method))

        version_path = os.path.join(base_path, model_type, f'v{version}')
        return version_path

    def save_model_artifact(self,
                            artifact_data: Any,
                            artifact_name: str,
                            model_version_path: str):
        """Save one artifact inside the given model version directory.

        The storage format is chosen from the file name suffix:

        * ``.pth``  -> ``torch.save``
        * ``.png``  -> ``artifact_data.savefig`` (only when the object has one)
        * ``.json`` -> ``json.dump`` (UTF-8, non-ASCII characters kept as-is)

        Args:
            artifact_data: Data to save (e.g. a model state dict or a figure object).
            artifact_name: Standardised artifact file name
                (e.g. ``'model.pth'``, ``'loss_curve.png'``).
            model_version_path: Model version directory; created if missing.

        Raises:
            ValueError: If the suffix is not supported, including a ``.png``
                name whose data object has no ``savefig`` method.
        """
        os.makedirs(model_version_path, exist_ok=True)
        full_path = os.path.join(model_version_path, artifact_name)

        if artifact_name.endswith('.pth'):
            torch.save(artifact_data, full_path)
        elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
            artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
        elif artifact_name.endswith('.json'):
            with open(full_path, 'w', encoding='utf-8') as f:
                json.dump(artifact_data, f, indent=2, ensure_ascii=False)
        else:
            raise ValueError(f"不支持的产物类型: {artifact_name}")

        print(f"产物已保存: {full_path}")

    # ------------------------------------------------------------------
    # Listing
    # ------------------------------------------------------------------

    @staticmethod
    def _sorted_subdirs(path: str) -> List[str]:
        """Return the names of the sub-directories of ``path`` in sorted order.

        Sorting makes the listing order, and therefore pagination, stable
        across calls. An unreadable or missing ``path`` yields an empty list.
        """
        try:
            names = os.listdir(path)
        except OSError:
            return []
        return sorted(n for n in names if os.path.isdir(os.path.join(path, n)))

    @staticmethod
    def _is_version_name(name: str) -> bool:
        """Return True for names of the exact form ``v<digits>``."""
        return name.startswith('v') and name[1:].isdigit()

    @staticmethod
    def _version_sort_key(name: str):
        """Sort ``v<digits>`` folders numerically (v2 before v10), others after them."""
        digits = name[1:]
        if digits.isdigit():
            return (0, int(digits), name)
        return (1, 0, name)

    def _version_folders(self, type_path: str) -> List[str]:
        """Return the version folder names (directories starting with ``v``) in ``type_path``."""
        names = [n for n in self._sorted_subdirs(type_path) if n.startswith('v')]
        return sorted(names, key=self._version_sort_key)

    def _model_type_dirs(self, training_mode: str, scope_path: str) -> List[str]:
        """Return the paths of all model-type directories below ``scope_path``.

        For non-global modes every sub-directory of the scope is a model-type
        directory. For ``'global'`` a sub-directory may instead be an
        aggregation-method level. The two are told apart by content: a
        directory that directly contains a ``v<digits>`` folder is a model-type
        directory, anything else is treated as an aggregation level whose
        children are the model-type directories.
        """
        type_dirs = []
        for child in self._sorted_subdirs(scope_path):
            child_path = os.path.join(scope_path, child)
            holds_versions = any(
                self._is_version_name(n) for n in self._sorted_subdirs(child_path)
            )
            if training_mode == 'global' and not holds_versions:
                type_dirs.extend(
                    os.path.join(child_path, model_type)
                    for model_type in self._sorted_subdirs(child_path)
                )
            else:
                type_dirs.append(child_path)
        return type_dirs

    def list_models(self,
                    page: Optional[int] = None,
                    page_size: Optional[int] = None) -> Dict:
        """List all models by scanning the directory structure.

        Top-level entries named ``'checkpoints'`` or ``versions.json`` are
        skipped, as is everything that is not a directory.

        Args:
            page: 1-based page number.
            page_size: Number of models per page.

        Returns:
            A dict with ``'models'`` (list of info dicts from
            :meth:`_parse_info_from_path`) and ``'pagination'`` (``total``,
            ``page``, ``page_size``, ``total_pages``). The list is sliced only
            when both ``page`` and ``page_size`` are positive; otherwise all
            models are returned.
        """
        all_models = []
        for training_mode in self._sorted_subdirs(self.model_dir):
            if training_mode in ('checkpoints', self.VERSION_FILE):
                continue
            mode_path = os.path.join(self.model_dir, training_mode)

            for scope in self._sorted_subdirs(mode_path):
                scope_path = os.path.join(mode_path, scope)

                for type_path in self._model_type_dirs(training_mode, scope_path):
                    for version_folder in self._version_folders(type_path):
                        version_path = os.path.join(type_path, version_folder)
                        model_info = self._parse_info_from_path(version_path)
                        if model_info:
                            all_models.append(model_info)

        total_count = len(all_models)
        # Non-positive values would turn into negative slice indices that wrap
        # around the end of the list, so they disable slicing instead.
        if page and page_size and page > 0 and page_size > 0:
            start_index = (page - 1) * page_size
            end_index = start_index + page_size
            paginated_models = all_models[start_index:end_index]
        else:
            paginated_models = all_models

        return {
            'models': paginated_models,
            'pagination': {
                'total': total_count,
                'page': page or 1,
                'page_size': page_size or total_count,
                'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
            }
        }

    def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
        """Derive model information from a version directory path.

        The path is interpreted relative to ``self.model_dir``. The last two
        components are the version folder and the model type, the first two
        are the training mode and the scope. A global path carries an
        aggregation method only when it has the extra directory level
        (five components); otherwise ``aggregation_method`` is ``None``.

        If ``metadata.json`` exists in the directory and holds a JSON object,
        its entries are merged into (and may override) the path-derived values.

        Args:
            version_path: Path of a model version directory.

        Returns:
            The info dict, or ``None`` when the path has too few components or
            the metadata file cannot be read or decoded.
        """
        try:
            norm_path = os.path.normpath(version_path)
            norm_model_dir = os.path.normpath(self.model_dir)

            relative_path = os.path.relpath(norm_path, norm_model_dir)
            parts = relative_path.split(os.sep)

            info = {
                'model_path': version_path,
                'version': parts[-1],
                'model_type': parts[-2]
            }

            training_mode = parts[0]
            info['training_mode'] = training_mode
            info['scope'] = parts[1]

            # mode/scope/agg/type/vN has five parts; mode/scope/type/vN has four.
            # With only four parts parts[2] is the model type, not an aggregation
            # method, so it must not be reported as one.
            if training_mode == 'global' and len(parts) >= 5:
                info['aggregation_method'] = parts[2]
            else:
                info['aggregation_method'] = None

            info['model_identifier'] = self.get_model_identifier(
                info['model_type'],
                training_mode,
                info['scope'],
                info['aggregation_method'],
            )

            metadata_path = os.path.join(version_path, 'metadata.json')
            if os.path.exists(metadata_path):
                with open(metadata_path, 'r', encoding='utf-8') as f:
                    metadata = json.load(f)
                if isinstance(metadata, dict):
                    info.update(metadata)

            return info
        except (IndexError, IOError, ValueError) as e:
            # ValueError covers json.JSONDecodeError, so one corrupt
            # metadata.json skips that model instead of aborting the listing.
            print(f"解析路径失败 {version_path}: {e}")
            return None
|
||
|
||
# Global model manager instance shared by the rest of the application.
# Constructing it at import time also creates the default model directory,
# because ModelManager.__init__ calls ensure_model_dir().
model_manager = ModelManager()