ShopTRAINING/server/utils/model_manager.py
xz2000 e999ed4af2 ### 2025-07-15 (续): 训练器与核心调用层重构
**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。

**1. 修改 `server/trainers/kan_trainer.py`**
*   **内容**: 完全重写了 `kan_trainer.py`。
    *   **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。
    *   **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。
    *   **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。
    *   **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。

**2. 修改 `server/trainers/tcn_trainer.py`**
*   **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。
    *   移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。
    *   全面转向使用 `model_manager` 进行版本控制和文件保存。
    *   统一了函数签名和进度反馈逻辑。

**3. 修改 `server/trainers/transformer_trainer.py`**
*   **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。
    *   移除了所有遗留的、基于文件名的路径拼接和保存逻辑。
    *   实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。

**4. 修改 `server/core/predictor.py`**
*   **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。
    *   **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。
    *   **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。
    *   **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。
    *   **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
2025-07-15 20:09:09 +08:00

226 lines
9.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
遵循层级式目录结构和文件版本管理规则
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Any
from threading import Lock
from core.config import DEFAULT_MODEL_DIR
class ModelManager:
"""
统一模型管理器,采用结构化目录和版本文件进行管理。
"""
VERSION_FILE = 'versions.json'
def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
self.model_dir = os.path.abspath(model_dir)
self.versions_path = os.path.join(self.model_dir, self.VERSION_FILE)
self._lock = Lock()
self.ensure_model_dir()
def ensure_model_dir(self):
"""确保模型根目录存在"""
os.makedirs(self.model_dir, exist_ok=True)
def _read_versions(self) -> Dict[str, int]:
"""线程安全地读取版本文件"""
with self._lock:
if not os.path.exists(self.versions_path):
return {}
try:
with open(self.versions_path, 'r', encoding='utf-8') as f:
return json.load(f)
except (json.JSONDecodeError, IOError):
return {}
def _write_versions(self, versions: Dict[str, int]):
"""线程安全地写入版本文件"""
with self._lock:
with open(self.versions_path, 'w', encoding='utf-8') as f:
json.dump(versions, f, indent=2)
def get_model_identifier(self,
model_type: str,
training_mode: str,
scope: str,
aggregation_method: Optional[str] = None) -> str:
"""
生成模型的唯一标识符用于版本文件中的key。
"""
if training_mode == 'global':
return f"{training_mode}_{scope}_{aggregation_method}_{model_type}"
return f"{training_mode}_{scope}_{model_type}"
def get_next_version_number(self, model_identifier: str) -> int:
"""
获取指定模型的下一个版本号(整数)。
"""
versions = self._read_versions()
current_version = versions.get(model_identifier, 0)
return current_version + 1
def update_version(self, model_identifier: str, new_version: int):
"""
更新模型的最新版本号。
"""
versions = self._read_versions()
versions[model_identifier] = new_version
self._write_versions(versions)
def get_model_version_path(self,
model_type: str,
training_mode: str,
scope: str,
version: int,
aggregation_method: Optional[str] = None) -> str:
"""
根据新规则生成模型版本目录的完整路径。
"""
base_path = os.path.join(self.model_dir, training_mode, scope)
if training_mode == 'global' and aggregation_method:
base_path = os.path.join(base_path, str(aggregation_method))
version_path = os.path.join(base_path, model_type, f'v{version}')
return version_path
def save_model_artifact(self,
artifact_data: Any,
artifact_name: str,
model_version_path: str):
"""
在指定的模型版本目录下保存一个产物。
Args:
artifact_data: 要保存的数据 (e.g., model state dict, figure object).
artifact_name: 标准化的产物文件名 (e.g., 'model.pth', 'loss_curve.png').
model_version_path: 模型版本目录的路径.
"""
os.makedirs(model_version_path, exist_ok=True)
full_path = os.path.join(model_version_path, artifact_name)
if artifact_name.endswith('.pth'):
torch.save(artifact_data, full_path)
elif artifact_name.endswith('.png') and hasattr(artifact_data, 'savefig'):
artifact_data.savefig(full_path, dpi=300, bbox_inches='tight')
elif artifact_name.endswith('.json'):
with open(full_path, 'w', encoding='utf-8') as f:
json.dump(artifact_data, f, indent=2, ensure_ascii=False)
else:
raise ValueError(f"不支持的产物类型: {artifact_name}")
print(f"产物已保存: {full_path}")
def list_models(self,
page: Optional[int] = None,
page_size: Optional[int] = None) -> Dict:
"""
通过扫描目录结构来列出所有模型。
"""
all_models = []
for training_mode in os.listdir(self.model_dir):
mode_path = os.path.join(self.model_dir, training_mode)
if not os.path.isdir(mode_path) or training_mode == 'checkpoints' or training_mode == self.VERSION_FILE:
continue
for scope in os.listdir(mode_path):
scope_path = os.path.join(mode_path, scope)
if not os.path.isdir(scope_path): continue
is_global_agg_level = False
if training_mode == 'global' and os.listdir(scope_path):
try:
first_item_path = os.path.join(scope_path, os.listdir(scope_path)[0])
if os.path.isdir(first_item_path):
is_global_agg_level = True
except IndexError:
continue
if is_global_agg_level:
for agg_method in os.listdir(scope_path):
agg_path = os.path.join(scope_path, agg_method)
if not os.path.isdir(agg_path): continue
for model_type in os.listdir(agg_path):
type_path = os.path.join(agg_path, model_type)
if not os.path.isdir(type_path): continue
for version_folder in os.listdir(type_path):
if version_folder.startswith('v'):
version_path = os.path.join(type_path, version_folder)
model_info = self._parse_info_from_path(version_path)
if model_info:
all_models.append(model_info)
else:
for model_type in os.listdir(scope_path):
type_path = os.path.join(scope_path, model_type)
if not os.path.isdir(type_path): continue
for version_folder in os.listdir(type_path):
if version_folder.startswith('v'):
version_path = os.path.join(type_path, version_folder)
model_info = self._parse_info_from_path(version_path)
if model_info:
all_models.append(model_info)
total_count = len(all_models)
if page and page_size:
start_index = (page - 1) * page_size
end_index = start_index + page_size
paginated_models = all_models[start_index:end_index]
else:
paginated_models = all_models
return {
'models': paginated_models,
'pagination': {
'total': total_count,
'page': page or 1,
'page_size': page_size or total_count,
'total_pages': (total_count + page_size - 1) // page_size if page_size and page_size > 0 else 1,
}
}
def _parse_info_from_path(self, version_path: str) -> Optional[Dict]:
"""从版本目录路径解析模型信息"""
try:
norm_path = os.path.normpath(version_path)
norm_model_dir = os.path.normpath(self.model_dir)
relative_path = os.path.relpath(norm_path, norm_model_dir)
parts = relative_path.split(os.sep)
info = {
'model_path': version_path,
'version': parts[-1],
'model_type': parts[-2]
}
training_mode = parts[0]
info['training_mode'] = training_mode
if training_mode == 'global':
info['scope'] = parts[1]
info['aggregation_method'] = parts[2]
info['model_identifier'] = self.get_model_identifier(info['model_type'], training_mode, info['scope'], info['aggregation_method'])
else:
info['scope'] = parts[1]
info['aggregation_method'] = None
info['model_identifier'] = self.get_model_identifier(info['model_type'], training_mode, info['scope'])
metadata_path = os.path.join(version_path, 'metadata.json')
if os.path.exists(metadata_path):
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
info.update(metadata)
return info
except (IndexError, IOError) as e:
print(f"解析路径失败 {version_path}: {e}")
return None
# 全局模型管理器实例
model_manager = ModelManager()