# ShopTRAINING/server/services/model_management_service.py
#
# 173 lines
# 7.1 KiB
# Python
# Raw Normal View History

import os
import json
import shutil
import uuid
from typing import Dict, Any, List, Tuple
from server.repositories.model_repository import ModelRepository
from server.utils.hashing import generate_hash
from server.services.version_manager import VersionManager
class ModelManagementService:
    """Manage the lifecycle of a trained model based on its training payload.

    Covers storage-path construction, version numbering, moving artifact
    files into place, writing ``metadata.json`` and recording the new
    version through the repository.
    """

    # Artifacts every training run must supply before anything is saved.
    REQUIRED_ARTIFACTS = frozenset(
        {'model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png'}
    )

    def __init__(self, repository: "ModelRepository", base_path: str = 'saved_models'):
        """Store the repository and root directory and create the version manager.

        Args:
            repository: Object whose ``add_model_version`` receives each saved record.
            base_path: Root directory under which all model versions are stored.
        """
        self.repository = repository
        self.base_path = base_path
        self.version_manager = VersionManager(base_path)

    def save_model_for_training(self, payload: Dict[str, Any], artifacts: Dict[str, str]):
        """Validate, store and register the artifacts of one training run.

        Dispatches on ``payload['training_mode']`` to build the target path
        and the record, moves the artifact files, writes the merged
        ``metadata.json`` and finally hands the record to the repository.

        Args:
            payload: Training request; must contain ``training_mode`` plus
                the fields required by that mode's handler.
            artifacts: Mapping of artifact file name to its temporary path.

        Returns:
            Tuple ``(final_path, db_record)``.

        Raises:
            ValueError: If a required artifact is missing or the training
                mode is unknown.
        """
        # Step 1: refuse to save anything unless every required artifact is present.
        missing = self.REQUIRED_ARTIFACTS - set(artifacts)
        if missing:
            # Sorted so the message is deterministic (set order is not).
            raise ValueError(f"模型产物不完整,缺少以下必需文件: {', '.join(sorted(missing))}")

        # Step 2: resolve the target path and record for the training mode.
        training_mode = payload.get('training_mode')
        handler_map = {
            'product': self._handle_product_training,
            'store': self._handle_store_training,
            'global': self._handle_global_training,
        }
        handler = handler_map.get(training_mode)
        if handler is None:
            raise ValueError(f"未知的训练模式: {training_mode}")
        final_path, db_record = handler(payload)

        # Step 3: create the version directory and move the artifact files in.
        os.makedirs(final_path, exist_ok=True)
        for artifact_name, temp_path in artifacts.items():
            # metadata.json is regenerated from db_record below, so it is not moved.
            if artifact_name == 'metadata.json':
                continue
            shutil.move(temp_path, os.path.join(final_path, artifact_name))

        # Step 4: merge the trainer's metadata into the record and write it out.
        # NOTE(review): trainer keys override service-generated ones such as
        # 'path' or 'version' — confirm this precedence is intended.
        db_record.update(self._read_trainer_metadata(artifacts.get('metadata.json')))
        with open(os.path.join(final_path, 'metadata.json'), 'w', encoding='utf-8') as f:
            json.dump(db_record, f, indent=4, ensure_ascii=False)

        # Step 5: register the final record.
        self.repository.add_model_version(db_record)
        return final_path, db_record

    @staticmethod
    def _read_trainer_metadata(metadata_path: str) -> Dict[str, Any]:
        """Load the trainer's metadata file, returning ``{}`` when unusable.

        A missing path, a missing file, invalid JSON, or JSON that is not an
        object all yield an empty dict so a bad metadata file cannot abort
        the save after the artifacts were already moved.
        """
        if not metadata_path or not os.path.exists(metadata_path):
            return {}
        with open(metadata_path, 'r', encoding='utf-8') as f:
            try:
                loaded = json.load(f)
            except json.JSONDecodeError:
                print(f"警告: 无法解析元数据文件 {metadata_path}。文件可能为空或格式不正确。")
                return {}
        if not isinstance(loaded, dict):
            # e.g. "null" or a list: cannot be merged into the record.
            print(f"警告: 元数据文件 {metadata_path} 的内容不是JSON对象,已忽略。")
            return {}
        return loaded

    def _get_scope_path_and_definition(self, ids: List[str]) -> Tuple[str, Dict]:
        """Return the path fragment and scope definition for a list of IDs.

        A single ID is used verbatim; several IDs are replaced by a hash.
        """
        if len(ids) == 1:
            return ids[0], {'type': 'single', 'id': ids[0]}
        # Hashing is only used when there is more than one ID.
        # NOTE(review): the hash is computed from the IDs in caller order while
        # the stored definition is sorted — verify generate_hash is
        # order-insensitive, otherwise the same set can map to two paths.
        hash_val = generate_hash(ids)
        return hash_val, {'type': 'hash', 'ids': sorted(ids)}

    def _build_version_record(
        self,
        training_mode: str,
        model_base_path: str,
        model_type: str,
        scope: Dict[str, Any],
        payload: Dict[str, Any],
        extra_fields: Dict[str, Any],
    ) -> Tuple[str, Dict]:
        """Allocate the next version and build ``(final_path, db_record)``.

        ``extra_fields`` are added after the standard keys, then the
        payload's ``metrics`` are merged last (so metric keys win on clashes).
        """
        # The version manager is queried with a '/'-separated key regardless
        # of the host OS path separator.
        next_version = self.version_manager.get_next_version(
            model_base_path.replace(os.sep, '/')
        )
        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
        db_record = {
            'model_uid': str(uuid.uuid4()),
            'training_mode': training_mode,
            'model_type': model_type,
            'version': next_version,
            'path': final_path,
            'scope': scope,
        }
        db_record.update(extra_fields)
        # `or {}` tolerates an explicit ``"metrics": None`` in the payload.
        db_record.update(payload.get('metrics') or {})
        return final_path, db_record

    def _handle_product_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build path and record for a single-product model.

        Raises:
            ValueError: If ``product_id`` is missing or empty.
            KeyError: If ``model_type`` is missing.
        """
        product_id = payload.get('product_id')
        if not product_id:
            raise ValueError("产品训练模式下 'product_id' 是必需的")
        model_type = payload['model_type']
        # str() so non-string IDs (e.g. ints) do not break os.path.join.
        model_base_path = os.path.join('product', str(product_id), model_type)
        return self._build_version_record(
            'product', model_base_path, model_type, {'product_id': product_id}, payload, {}
        )

    def _handle_store_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build path and record for a single-store model.

        Raises:
            ValueError: If ``store_id`` is missing or empty.
            KeyError: If ``model_type`` is missing.
        """
        store_id = payload.get('store_id')
        if not store_id:
            raise ValueError("店铺训练模式下 'store_id' 是必需的")
        model_type = payload['model_type']
        scope_path, scope_definition = self._get_scope_path_and_definition([store_id])
        # str() so non-string IDs (e.g. ints) do not break os.path.join.
        model_base_path = os.path.join('store', str(scope_path), model_type)
        return self._build_version_record(
            'store', model_base_path, model_type, scope_definition, payload, {}
        )

    def _handle_global_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Build path and record for a global model.

        The scope directory is ``S_<stores>`` and/or ``P_<products>`` joined
        by ``_``, or ``all`` when neither ID list is given.

        Raises:
            KeyError: If ``model_type`` is missing.
        """
        store_ids = payload.get('store_ids', [])
        product_ids = payload.get('product_ids', [])
        model_type = payload['model_type']
        aggregation = payload.get('aggregation_method', 'all')
        if aggregation is None:
            # An explicit null would otherwise crash os.path.join below.
            aggregation = 'all'

        scope_path_parts = []
        scope_definition = {}
        for prefix, key, ids in (('S', 'stores', store_ids), ('P', 'products', product_ids)):
            if not ids:
                continue
            fragment, definition = self._get_scope_path_and_definition(ids)
            scope_path_parts.append(f"{prefix}_{fragment}")
            scope_definition[key] = definition
        scope_path = "_".join(scope_path_parts) if scope_path_parts else "all"

        model_base_path = os.path.join('global', scope_path, aggregation, model_type)
        return self._build_version_record(
            'global',
            model_base_path,
            model_type,
            scope_definition,
            payload,
            {'aggregation_method': aggregation},
        )