import os
import json
import shutil
import uuid
from typing import Dict, Any, List, Tuple

from server.repositories.model_repository import ModelRepository
from server.utils.hashing import generate_hash
from server.services.version_manager import VersionManager


class ModelManagementService:
    """Manage the lifecycle of a trained model based on a training payload.

    Covers building the storage path, allocating a version, moving the
    artifact files into place, and recording the result in the repository.
    """

    # Artifact names that must all be present before anything is saved.
    REQUIRED_ARTIFACTS = frozenset(
        {'model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png'}
    )

    def __init__(self, repository: ModelRepository, base_path='saved_models'):
        """Store the collaborators used when saving models.

        Args:
            repository: Receives the final record via ``add_model_version``.
            base_path: Root directory under which model versions are stored;
                also handed to the ``VersionManager``.
        """
        self.repository = repository
        self.base_path = base_path
        self.version_manager = VersionManager(base_path)

    def save_model_for_training(self, payload: Dict[str, Any], artifacts: Dict[str, str]):
        """Validate artifacts, dispatch on ``training_mode`` and persist the model.

        Args:
            payload: Training request. ``training_mode`` selects the handler
                (``'product'``, ``'store'`` or ``'global'``); the remaining
                keys are read by that handler.
            artifacts: Mapping of artifact file name to its temporary path.

        Returns:
            A ``(final_path, db_record)`` tuple: the version directory the
            artifacts were moved into and the record passed to the repository.

        Raises:
            ValueError: If a required artifact is missing or the training
                mode is unknown.
        """
        # Step 1: refuse to save anything unless every required artifact is
        # provided. The names are sorted so the message is deterministic.
        missing = self.REQUIRED_ARTIFACTS.difference(artifacts)
        if missing:
            raise ValueError(f"模型产物不完整,缺少以下必需文件: {', '.join(sorted(missing))}")

        # Step 2: resolve the target path and the base record for this mode.
        training_mode = payload.get('training_mode')
        handler_map = {
            'product': self._handle_product_training,
            'store': self._handle_store_training,
            'global': self._handle_global_training,
        }
        handler = handler_map.get(training_mode)
        if handler is None:
            raise ValueError(f"未知的训练模式: {training_mode}")

        final_path, db_record = handler(payload)

        # Step 3: create the version directory and move the artifacts in.
        os.makedirs(final_path, exist_ok=True)
        for artifact_name, temp_path in artifacts.items():
            # metadata.json is regenerated from db_record below, so the
            # trainer's temporary copy is read in place rather than moved.
            if artifact_name == 'metadata.json':
                continue
            shutil.move(temp_path, os.path.join(final_path, artifact_name))

        # Step 4: merge the trainer's metadata into the record and write the
        # final metadata file.
        trainer_metadata = self._load_trainer_metadata(artifacts.get('metadata.json'))
        # Trainer-provided keys replace same-named keys of the record.
        db_record.update(trainer_metadata)

        with open(os.path.join(final_path, 'metadata.json'), 'w', encoding='utf-8') as f:
            json.dump(db_record, f, indent=4, ensure_ascii=False)

        # Step 5: register the final record with the repository.
        self.repository.add_model_version(db_record)

        return final_path, db_record

    def _load_trainer_metadata(self, metadata_path) -> Dict[str, Any]:
        """Read the trainer's metadata file, returning ``{}`` when it is unusable.

        A missing path, a file that does not exist, an empty or malformed
        file, and JSON whose top level is not an object all yield an empty
        dict (the latter two with a printed warning), so a bad metadata file
        does not abort a save whose artifacts were already moved.
        """
        if not metadata_path or not os.path.exists(metadata_path):
            return {}

        with open(metadata_path, 'r', encoding='utf-8') as f:
            try:
                loaded = json.load(f)
            except json.JSONDecodeError:
                print(f"警告: 无法解析元数据文件 {metadata_path}。文件可能为空或格式不正确。")
                return {}

        # Valid JSON such as `null` or `[]` cannot be merged into the record.
        if not isinstance(loaded, dict):
            print(f"警告: 元数据文件 {metadata_path} 的内容不是 JSON 对象,已忽略。")
            return {}
        return loaded

    def _get_scope_path_and_definition(self, ids: List[str]) -> Tuple[str, Dict]:
        """Return the path fragment and scope definition for a list of IDs.

        A single ID is used verbatim as the path fragment; several IDs are
        replaced by ``generate_hash(ids)``, and the definition then lists
        the IDs in sorted order.
        """
        if len(ids) == 1:
            return ids[0], {'type': 'single', 'id': ids[0]}

        # Hashing is only used when there is more than one ID.
        # NOTE(review): ids are hashed in caller order — confirm that
        # generate_hash is order-independent if equal ID sets must share a path.
        hash_val = generate_hash(ids)
        return hash_val, {'type': 'hash', 'ids': sorted(ids)}

    def _build_versioned_record(
        self,
        training_mode: str,
        model_type: str,
        model_base_path: str,
        scope: Dict,
        payload: Dict[str, Any],
        extra: Dict[str, Any] = None,
    ) -> Tuple[str, Dict]:
        """Allocate the next version under a model path and build its record.

        Args:
            training_mode: Value stored under ``'training_mode'``.
            model_type: Value stored under ``'model_type'``.
            model_base_path: Path of the model relative to ``self.base_path``.
            scope: Value stored under ``'scope'``.
            payload: Training payload; only its ``'metrics'`` entry is read.
            extra: Optional mode-specific fields placed after ``'scope'``.

        Returns:
            ``(final_path, db_record)`` where ``final_path`` is
            ``<base_path>/<model_base_path>/v<version>``.
        """
        # The version manager is always queried with '/'-separated keys,
        # whatever the OS path separator is.
        version_key = model_base_path.replace(os.sep, '/')
        next_version = self.version_manager.get_next_version(version_key)
        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')

        db_record = {
            'model_uid': str(uuid.uuid4()),
            'training_mode': training_mode,
            'model_type': model_type,
            'version': next_version,
            'path': final_path,
            'scope': scope,
            **(extra or {}),
            # Metrics are spread last, so a metric key that collides with a
            # field above replaces it. `or {}` tolerates "metrics": None.
            **(payload.get('metrics') or {}),
        }
        return final_path, db_record

    def _handle_product_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build the path and record for a single-product model.

        Raises:
            ValueError: If ``payload`` has no truthy ``'product_id'``.
            KeyError: If ``payload`` has no ``'model_type'``.
        """
        product_id = payload.get('product_id')
        if not product_id:
            raise ValueError("产品训练模式下 'product_id' 是必需的")

        model_type = payload['model_type']
        model_base_path = os.path.join('product', product_id, model_type)
        return self._build_versioned_record(
            'product',
            model_type,
            model_base_path,
            {'product_id': product_id},
            payload,
        )

    def _handle_store_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build the path and record for a single-store model.

        Raises:
            ValueError: If ``payload`` has no truthy ``'store_id'``.
            KeyError: If ``payload`` has no ``'model_type'``.
        """
        store_id = payload.get('store_id')
        if not store_id:
            raise ValueError("店铺训练模式下 'store_id' 是必需的")

        model_type = payload['model_type']
        model_base_path = os.path.join('store', store_id, model_type)
        return self._build_versioned_record(
            'store',
            model_type,
            model_base_path,
            {'type': 'single', 'id': store_id},
            payload,
        )

    def _handle_global_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build the path and record for a global model.

        The scope fragment is ``S_<stores>`` and/or ``P_<products>`` joined
        by ``_``, or ``all`` when neither ``store_ids`` nor ``product_ids``
        is given. ``aggregation_method`` defaults to ``'all'``.

        Raises:
            KeyError: If ``payload`` has no ``'model_type'``.
        """
        store_ids = payload.get('store_ids', [])
        product_ids = payload.get('product_ids', [])
        model_type = payload['model_type']
        aggregation = payload.get('aggregation_method', 'all')

        scope_path_parts = []
        scope_definition = {}

        if store_ids:
            s_path, s_def = self._get_scope_path_and_definition(store_ids)
            scope_path_parts.append(f"S_{s_path}")
            scope_definition['stores'] = s_def

        if product_ids:
            p_path, p_def = self._get_scope_path_and_definition(product_ids)
            scope_path_parts.append(f"P_{p_path}")
            scope_definition['products'] = p_def

        scope_path = "_".join(scope_path_parts) if scope_path_parts else "all"
        model_base_path = os.path.join('global', scope_path, aggregation, model_type)

        return self._build_versioned_record(
            'global',
            model_type,
            model_base_path,
            scope_definition,
            payload,
            extra={'aggregation_method': aggregation},
        )