调整文件保存代码,根据前端三者训练模式,选择文件保存函数
This commit is contained in:
parent
18f505a090
commit
e51aaa5cf6
19
server/data/base_source.py
Normal file
19
server/data/base_source.py
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
import pandas as pd
|
||||||
|
|
||||||
|
class IDataSource(ABC):
|
||||||
|
"""
|
||||||
|
数据源接口,定义了获取训练数据的标准方法。
|
||||||
|
"""
|
||||||
|
@abstractmethod
|
||||||
|
def get_data(self, **filters) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
根据指定的筛选条件获取数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**filters: 一个包含筛选条件的字典,例如 store_ids=['S001'], product_ids=['P001']。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含所需数据的 pandas DataFrame。
|
||||||
|
"""
|
||||||
|
pass
|
46
server/data/parquet_source.py
Normal file
46
server/data/parquet_source.py
Normal file
@ -0,0 +1,46 @@
|
|||||||
|
from .base_source import IDataSource
|
||||||
|
import pandas as pd
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
class ParquetDataSource(IDataSource):
|
||||||
|
"""
|
||||||
|
一个从Parquet文件加载数据的数据源实现。
|
||||||
|
"""
|
||||||
|
def __init__(self, file_path: str):
|
||||||
|
"""
|
||||||
|
初始化Parquet数据源。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
file_path: Parquet文件的路径。
|
||||||
|
"""
|
||||||
|
self.file_path = file_path
|
||||||
|
try:
|
||||||
|
self._df = pd.read_parquet(self.file_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
print(f"警告: Parquet文件未找到于 {self.file_path}。将使用空DataFrame。")
|
||||||
|
self._df = pd.DataFrame()
|
||||||
|
|
||||||
|
def get_data(self, store_ids: Optional[List[str]] = None, product_ids: Optional[List[str]] = None, **kwargs) -> pd.DataFrame:
|
||||||
|
"""
|
||||||
|
从Parquet文件中筛选并返回数据。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
store_ids: 要筛选的店铺ID列表。
|
||||||
|
product_ids: 要筛选的药品ID列表。
|
||||||
|
**kwargs: 其他预留的筛选参数。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个经过筛选的pandas DataFrame。
|
||||||
|
"""
|
||||||
|
if self._df.empty:
|
||||||
|
return self._df
|
||||||
|
|
||||||
|
filtered_df = self._df.copy()
|
||||||
|
|
||||||
|
if store_ids and 'store_id' in filtered_df.columns:
|
||||||
|
filtered_df = filtered_df[filtered_df['store_id'].isin(store_ids)]
|
||||||
|
|
||||||
|
if product_ids and 'product_id' in filtered_df.columns:
|
||||||
|
filtered_df = filtered_df[filtered_df['product_id'].isin(product_ids)]
|
||||||
|
|
||||||
|
return filtered_df
|
70
server/repositories/model_repository.py
Normal file
70
server/repositories/model_repository.py
Normal file
@ -0,0 +1,70 @@
|
|||||||
|
import sqlite3
|
||||||
|
from typing import List, Dict, Any, Optional
|
||||||
|
|
||||||
|
class ModelRepository:
|
||||||
|
"""
|
||||||
|
封装所有对SQLite数据库中 `model_registry` 表的CRUD操作。
|
||||||
|
在初期开发阶段,所有方法都是空操作占位符。
|
||||||
|
"""
|
||||||
|
def __init__(self, db_path: str):
|
||||||
|
"""
|
||||||
|
初始化仓库。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
db_path: SQLite数据库文件的路径。
|
||||||
|
"""
|
||||||
|
self.db_path = db_path
|
||||||
|
self.conn = None # 连接将在需要时建立
|
||||||
|
|
||||||
|
def _get_connection(self):
|
||||||
|
"""建立并返回数据库连接。"""
|
||||||
|
if self.conn is None:
|
||||||
|
# 在实际实现中,这里会连接到 self.db_path
|
||||||
|
# self.conn = sqlite3.connect(self.db_path)
|
||||||
|
# self.conn.row_factory = sqlite3.Row
|
||||||
|
pass
|
||||||
|
return self.conn
|
||||||
|
|
||||||
|
def add_model_version(self, model_data: Dict[str, Any]):
|
||||||
|
"""
|
||||||
|
向数据库中添加一条新的模型版本记录。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_data: 一个包含模型元数据的字典。
|
||||||
|
"""
|
||||||
|
print(f"[Repository] (空操作) 准备保存模型记录: {model_data.get('model_uid')}")
|
||||||
|
# 实际实现将包含SQL INSERT语句
|
||||||
|
pass
|
||||||
|
|
||||||
|
def find_by_uid(self, model_uid: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
根据模型的唯一ID查找模型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_uid: 模型的唯一ID。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含模型数据的字典,如果未找到则返回None。
|
||||||
|
"""
|
||||||
|
print(f"[Repository] (空操作) 准备根据UID查找模型: {model_uid}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def find_all(self, **filters) -> List[Dict[str, Any]]:
|
||||||
|
"""
|
||||||
|
根据指定的筛选条件查找所有匹配的模型。
|
||||||
|
|
||||||
|
Args:
|
||||||
|
**filters: 筛选条件,例如 training_mode='global'。
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
一个包含所有匹配模型记录的字典列表。
|
||||||
|
"""
|
||||||
|
print(f"[Repository] (空操作) 准备查找模型,筛选条件: {filters}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
"""关闭数据库连接。"""
|
||||||
|
if self.conn:
|
||||||
|
# self.conn.close()
|
||||||
|
self.conn = None
|
||||||
|
print("[Repository] (空操作) 数据库连接已关闭。")
|
173
server/services/model_management_service.py
Normal file
173
server/services/model_management_service.py
Normal file
@ -0,0 +1,173 @@
|
|||||||
|
import os
|
||||||
|
import json
|
||||||
|
import shutil
|
||||||
|
import uuid
|
||||||
|
from typing import Dict, Any, List, Tuple
|
||||||
|
|
||||||
|
from server.repositories.model_repository import ModelRepository
|
||||||
|
from server.utils.hashing import generate_hash
|
||||||
|
from server.services.version_manager import VersionManager
|
||||||
|
|
||||||
|
class ModelManagementService:
|
||||||
|
"""
|
||||||
|
负责根据训练负载(payload)来管理模型的整个生命周期,
|
||||||
|
包括路径构建、版本控制、文件保存和数据库记录。
|
||||||
|
"""
|
||||||
|
def __init__(self, repository: ModelRepository, base_path='saved_models'):
|
||||||
|
self.repository = repository
|
||||||
|
self.base_path = base_path
|
||||||
|
self.version_manager = VersionManager(base_path)
|
||||||
|
|
||||||
|
def save_model_for_training(self, payload: Dict[str, Any], artifacts: Dict[str, str]):
|
||||||
|
"""
|
||||||
|
主分发函数,根据 training_mode 调用相应的处理方法。
|
||||||
|
在保存前,会强制校验所有必需的产物是否都已提供。
|
||||||
|
"""
|
||||||
|
# 步骤1:强制校验产物完整性
|
||||||
|
REQUIRED_ARTIFACTS = {'model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png'}
|
||||||
|
provided_artifacts = set(artifacts.keys())
|
||||||
|
|
||||||
|
if not REQUIRED_ARTIFACTS.issubset(provided_artifacts):
|
||||||
|
missing = REQUIRED_ARTIFACTS - provided_artifacts
|
||||||
|
raise ValueError(f"模型产物不完整,缺少以下必需文件: {', '.join(missing)}")
|
||||||
|
|
||||||
|
# 步骤2:根据训练模式获取路径和数据库记录
|
||||||
|
training_mode = payload.get('training_mode')
|
||||||
|
handler_map = {
|
||||||
|
'product': self._handle_product_training,
|
||||||
|
'store': self._handle_store_training,
|
||||||
|
'global': self._handle_global_training,
|
||||||
|
}
|
||||||
|
handler = handler_map.get(training_mode)
|
||||||
|
if not handler:
|
||||||
|
raise ValueError(f"未知的训练模式: {training_mode}")
|
||||||
|
final_path, db_record = handler(payload)
|
||||||
|
|
||||||
|
# 步骤3:创建目录并移动产物文件
|
||||||
|
os.makedirs(final_path, exist_ok=True)
|
||||||
|
for artifact_name, temp_path in artifacts.items():
|
||||||
|
# metadata.json 由db_record生成,特殊处理
|
||||||
|
if artifact_name == 'metadata.json':
|
||||||
|
continue
|
||||||
|
shutil.move(temp_path, os.path.join(final_path, artifact_name))
|
||||||
|
|
||||||
|
# 步骤4:写入最终的元数据文件
|
||||||
|
# 将训练器生成的元数据与服务层生成的元数据合并
|
||||||
|
# 从临时文件中读取训练器生成的元数据
|
||||||
|
trainer_metadata = {}
|
||||||
|
metadata_path = artifacts.get('metadata.json')
|
||||||
|
if metadata_path and os.path.exists(metadata_path):
|
||||||
|
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||||
|
# 增加异常处理,防止因文件为空或格式错误导致整个流程失败
|
||||||
|
try:
|
||||||
|
trainer_metadata = json.load(f)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print(f"警告: 无法解析元数据文件 {metadata_path}。文件可能为空或格式不正确。")
|
||||||
|
|
||||||
|
# 合并元数据
|
||||||
|
db_record.update(trainer_metadata)
|
||||||
|
|
||||||
|
with open(os.path.join(final_path, 'metadata.json'), 'w', encoding='utf-8') as f:
|
||||||
|
json.dump(db_record, f, indent=4, ensure_ascii=False)
|
||||||
|
|
||||||
|
# 步骤5:将最终记录添加到数据库
|
||||||
|
self.repository.add_model_version(db_record)
|
||||||
|
|
||||||
|
return final_path, db_record
|
||||||
|
|
||||||
|
def _get_scope_path_and_definition(self, ids: List[str]) -> Tuple[str, Dict]:
|
||||||
|
"""根据ID列表获取路径片段和范围定义 (条件哈希)"""
|
||||||
|
if len(ids) == 1:
|
||||||
|
return ids[0], {'type': 'single', 'id': ids[0]}
|
||||||
|
|
||||||
|
# 只有当ID多于一个时才使用哈希
|
||||||
|
hash_val = generate_hash(ids)
|
||||||
|
return hash_val, {'type': 'hash', 'ids': sorted(ids)}
|
||||||
|
|
||||||
|
def _handle_product_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
|
||||||
|
product_id = payload.get('product_id')
|
||||||
|
if not product_id:
|
||||||
|
raise ValueError("产品训练模式下 'product_id' 是必需的")
|
||||||
|
model_type = payload['model_type']
|
||||||
|
|
||||||
|
model_base_path = os.path.join('product', product_id, model_type)
|
||||||
|
next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
|
final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
|
||||||
|
model_uid = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db_record = {
|
||||||
|
'model_uid': model_uid,
|
||||||
|
'training_mode': 'product',
|
||||||
|
'model_type': model_type,
|
||||||
|
'version': next_version,
|
||||||
|
'path': final_path,
|
||||||
|
'scope': {'product_id': product_id},
|
||||||
|
**payload.get('metrics', {})
|
||||||
|
}
|
||||||
|
return final_path, db_record
|
||||||
|
|
||||||
|
def _handle_store_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
|
||||||
|
store_id = payload.get('store_id')
|
||||||
|
if not store_id:
|
||||||
|
raise ValueError("店铺训练模式下 'store_id' 是必需的")
|
||||||
|
model_type = payload['model_type']
|
||||||
|
|
||||||
|
scope_path = store_id
|
||||||
|
scope_definition = {'type': 'single', 'id': store_id}
|
||||||
|
|
||||||
|
model_base_path = os.path.join('store', scope_path, model_type)
|
||||||
|
next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
|
final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
|
||||||
|
model_uid = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db_record = {
|
||||||
|
'model_uid': model_uid,
|
||||||
|
'training_mode': 'store',
|
||||||
|
'model_type': model_type,
|
||||||
|
'version': next_version,
|
||||||
|
'path': final_path,
|
||||||
|
'scope': scope_definition,
|
||||||
|
**payload.get('metrics', {})
|
||||||
|
}
|
||||||
|
return final_path, db_record
|
||||||
|
|
||||||
|
def _handle_global_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
|
||||||
|
store_ids = payload.get('store_ids', [])
|
||||||
|
product_ids = payload.get('product_ids', [])
|
||||||
|
model_type = payload['model_type']
|
||||||
|
aggregation = payload.get('aggregation_method', 'all')
|
||||||
|
|
||||||
|
scope_path_parts = []
|
||||||
|
scope_definition = {}
|
||||||
|
|
||||||
|
if store_ids:
|
||||||
|
s_path, s_def = self._get_scope_path_and_definition(store_ids)
|
||||||
|
scope_path_parts.append(f"S_{s_path}")
|
||||||
|
scope_definition['stores'] = s_def
|
||||||
|
|
||||||
|
if product_ids:
|
||||||
|
p_path, p_def = self._get_scope_path_and_definition(product_ids)
|
||||||
|
scope_path_parts.append(f"P_{p_path}")
|
||||||
|
scope_definition['products'] = p_def
|
||||||
|
|
||||||
|
scope_path = "_".join(scope_path_parts) if scope_path_parts else "all"
|
||||||
|
|
||||||
|
model_base_path = os.path.join('global', scope_path, aggregation, model_type)
|
||||||
|
next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
|
||||||
|
|
||||||
|
final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
|
||||||
|
model_uid = str(uuid.uuid4())
|
||||||
|
|
||||||
|
db_record = {
|
||||||
|
'model_uid': model_uid,
|
||||||
|
'training_mode': 'global',
|
||||||
|
'model_type': model_type,
|
||||||
|
'version': next_version,
|
||||||
|
'path': final_path,
|
||||||
|
'scope': scope_definition,
|
||||||
|
'aggregation_method': aggregation,
|
||||||
|
**payload.get('metrics', {})
|
||||||
|
}
|
||||||
|
return final_path, db_record
|
Loading…
x
Reference in New Issue
Block a user