# ShopTRAINING/server/services/model_management_service.py

import os
import json
import shutil
import uuid
from typing import Dict, Any, List, Tuple
from server.repositories.model_repository import ModelRepository
from server.utils.hashing import generate_hash
from server.services.version_manager import VersionManager
class ModelManagementService:
    """
    Manage the lifecycle of a trained model driven by the training payload.

    Responsibilities: building the storage path, allocating the next version,
    moving artifact files into place, writing the final ``metadata.json`` and
    registering the record with the repository.
    """

    # Artifacts every training run must provide before anything is persisted.
    REQUIRED_ARTIFACTS = frozenset(
        {'model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png'}
    )

    # Record fields computed by this service. Trainer-produced metadata is
    # merged into the record but must not override these, otherwise the
    # record could point somewhere other than where the files were moved.
    _SERVICE_OWNED_KEYS = (
        'model_uid',
        'training_mode',
        'model_type',
        'version',
        'path',
        'scope',
        'aggregation_method',
    )

    def __init__(self, repository: 'ModelRepository', base_path='saved_models'):
        """
        Args:
            repository: Persistence backend; ``add_model_version(record)`` is
                called once per saved model.
            base_path: Root directory under which versioned model directories
                are created (also handed to ``VersionManager``).
        """
        self.repository = repository
        self.base_path = base_path
        self.version_manager = VersionManager(base_path)

    def save_model_for_training(
        self, payload: Dict[str, Any], artifacts: Dict[str, str]
    ) -> Tuple[str, Dict[str, Any]]:
        """Validate, place and register the artifacts of one training run.

        Args:
            payload: Training request. ``training_mode`` selects the handler
                ('product', 'store' or 'global'); each handler reads the keys
                it needs (ids, ``model_type``, optional ``metrics`` ...).
            artifacts: Mapping of artifact file name -> temporary file path.
                Must contain at least ``REQUIRED_ARTIFACTS``. Every entry
                except ``metadata.json`` is moved into the final directory;
                ``metadata.json`` is only read and merged into the record.

        Returns:
            ``(final_path, db_record)``: the versioned directory the artifacts
            were moved into, and the record written to ``metadata.json`` and
            passed to ``repository.add_model_version``.

        Raises:
            ValueError: if required artifacts are missing, the training mode
                is unknown, or the handler finds its required id missing.
            KeyError: if the payload has no ``model_type`` (raised by the
                handlers).
        """
        # Step 1: refuse to persist anything unless every artifact is present.
        missing = self.REQUIRED_ARTIFACTS - set(artifacts)
        if missing:
            # Sorted so the message is deterministic (set order is not).
            raise ValueError(f"模型产物不完整,缺少以下必需文件: {', '.join(sorted(missing))}")

        # Step 2: resolve the final path and base record for this training mode.
        training_mode = payload.get('training_mode')
        handler = {
            'product': self._handle_product_training,
            'store': self._handle_store_training,
            'global': self._handle_global_training,
        }.get(training_mode)
        if handler is None:
            raise ValueError(f"未知的训练模式: {training_mode}")
        final_path, db_record = handler(payload)

        # Step 3: create the target directory and move the artifact files in.
        os.makedirs(final_path, exist_ok=True)
        for artifact_name, temp_path in artifacts.items():
            # metadata.json is regenerated from db_record below, not moved.
            if artifact_name == 'metadata.json':
                continue
            shutil.move(temp_path, os.path.join(final_path, artifact_name))

        # Step 4: merge the trainer's metadata into the record while keeping
        # the service-owned identity fields authoritative.
        trainer_metadata = self._read_trainer_metadata(artifacts.get('metadata.json'))
        owned = {key: db_record[key] for key in self._SERVICE_OWNED_KEYS if key in db_record}
        db_record.update(trainer_metadata)
        db_record.update(owned)

        with open(os.path.join(final_path, 'metadata.json'), 'w', encoding='utf-8') as f:
            json.dump(db_record, f, indent=4, ensure_ascii=False)

        # Step 5: register the final record.
        self.repository.add_model_version(db_record)
        return final_path, db_record

    @staticmethod
    def _read_trainer_metadata(metadata_path) -> Dict[str, Any]:
        """Load the trainer's temporary metadata file as a dict.

        ``metadata_path`` may be None or point to a missing file, in which
        case ``{}`` is returned. An empty/invalid file, or JSON that is not an
        object, prints a warning and also yields ``{}`` so the save proceeds.
        """
        if not metadata_path or not os.path.exists(metadata_path):
            return {}
        warning = f"警告: 无法解析元数据文件 {metadata_path}。文件可能为空或格式不正确。"
        with open(metadata_path, 'r', encoding='utf-8') as f:
            try:
                data = json.load(f)
            except json.JSONDecodeError:
                print(warning)
                return {}
        if not isinstance(data, dict):
            # A list/number/string cannot be merged into the record.
            print(warning)
            return {}
        return data

    def _get_scope_path_and_definition(self, ids: List[str]) -> Tuple[str, Dict]:
        """Return the path segment and scope definition for a list of ids.

        A single id is used verbatim; several ids are collapsed into
        ``generate_hash(ids)`` so the directory name stays short.
        """
        if len(ids) == 1:
            return ids[0], {'type': 'single', 'id': ids[0]}
        # NOTE(review): generate_hash receives ids in the caller's order while
        # the stored definition is sorted; if generate_hash does not sort
        # internally, the same id set in a different order maps to a
        # different directory — confirm.
        hash_val = generate_hash(ids)
        return hash_val, {'type': 'hash', 'ids': sorted(ids)}

    def _build_versioned_record(
        self,
        training_mode: str,
        model_base_path: str,
        model_type: str,
        scope: Dict[str, Any],
        payload: Dict[str, Any],
        extra: Dict[str, Any] = None,
    ) -> Tuple[str, Dict[str, Any]]:
        """Allocate the next version under ``model_base_path`` and build its record.

        Shared by all training-mode handlers. ``extra`` fields are inserted
        after ``scope`` and before the payload's metrics, so the record key
        order is the same as each handler produced before.
        """
        # VersionManager is queried with '/'-separated keys on every OS.
        next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
        db_record = {
            'model_uid': str(uuid.uuid4()),
            'training_mode': training_mode,
            'model_type': model_type,
            'version': next_version,
            'path': final_path,
            'scope': scope,
            **(extra or {}),
            # `metrics` may be absent or explicitly None in the payload.
            **(payload.get('metrics') or {}),
        }
        return final_path, db_record

    def _handle_product_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Layout: ``<base>/product/<product_id>/<model_type>/v<N>``."""
        product_id = payload.get('product_id')
        if not product_id:
            raise ValueError("产品训练模式下 'product_id' 是必需的")
        model_type = payload['model_type']
        return self._build_versioned_record(
            'product',
            os.path.join('product', product_id, model_type),
            model_type,
            {'product_id': product_id},
            payload,
        )

    def _handle_store_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Layout: ``<base>/store/<store_id>/<model_type>/v<N>``."""
        store_id = payload.get('store_id')
        if not store_id:
            raise ValueError("店铺训练模式下 'store_id' 是必需的")
        model_type = payload['model_type']
        return self._build_versioned_record(
            'store',
            os.path.join('store', store_id, model_type),
            model_type,
            {'type': 'single', 'id': store_id},
            payload,
        )

    def _handle_global_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
        """Layout: ``<base>/global/<scope>/<aggregation>/<model_type>/v<N>``.

        ``<scope>`` joins ``S_<stores>`` and/or ``P_<products>`` with '_',
        or is ``all`` when neither id list is given.
        """
        store_ids = payload.get('store_ids', [])
        product_ids = payload.get('product_ids', [])
        model_type = payload['model_type']
        aggregation = payload.get('aggregation_method', 'all')

        scope_path_parts = []
        scope_definition = {}
        if store_ids:
            s_path, s_def = self._get_scope_path_and_definition(store_ids)
            scope_path_parts.append(f"S_{s_path}")
            scope_definition['stores'] = s_def
        if product_ids:
            p_path, p_def = self._get_scope_path_and_definition(product_ids)
            scope_path_parts.append(f"P_{p_path}")
            scope_definition['products'] = p_def
        scope_path = "_".join(scope_path_parts) if scope_path_parts else "all"

        return self._build_versioned_record(
            'global',
            os.path.join('global', scope_path, aggregation, model_type),
            model_type,
            scope_definition,
            payload,
            extra={'aggregation_method': aggregation},
        )