diff --git a/server/data/base_source.py b/server/data/base_source.py
new file mode 100644
index 0000000..bb0d4cd
--- /dev/null
+++ b/server/data/base_source.py
@@ -0,0 +1,19 @@
+from abc import ABC, abstractmethod
+import pandas as pd
+
+class IDataSource(ABC):
+    """
+    Data source interface defining the standard way to fetch training data.
+    """
+    @abstractmethod
+    def get_data(self, **filters) -> pd.DataFrame:
+        """
+        Fetch data matching the given filter conditions.
+
+        Args:
+            **filters: Arbitrary keyword filters, e.g. store_ids=['S001'], product_ids=['P001'].
+
+        Returns:
+            A pandas DataFrame containing the requested data.
+        """
+        pass
\ No newline at end of file
diff --git a/server/data/parquet_source.py b/server/data/parquet_source.py
new file mode 100644
index 0000000..7499a13
--- /dev/null
+++ b/server/data/parquet_source.py
@@ -0,0 +1,46 @@
+from .base_source import IDataSource
+import pandas as pd
+from typing import List, Optional
+
+class ParquetDataSource(IDataSource):
+    """
+    A data source implementation that loads data from a Parquet file.
+    """
+    def __init__(self, file_path: str):
+        """
+        Initialize the Parquet data source.
+
+        Args:
+            file_path: Path to the Parquet file.
+        """
+        self.file_path = file_path
+        try:
+            self._df = pd.read_parquet(self.file_path)
+        except FileNotFoundError:
+            print(f"Warning: Parquet file not found at {self.file_path}. Falling back to an empty DataFrame.")
+            self._df = pd.DataFrame()
+
+    def get_data(self, store_ids: Optional[List[str]] = None, product_ids: Optional[List[str]] = None, **kwargs) -> pd.DataFrame:
+        """
+        Filter the Parquet data and return the matching rows.
+
+        Args:
+            store_ids: List of store IDs to filter on.
+            product_ids: List of product IDs to filter on.
+            **kwargs: Additional filter parameters reserved for future use.
+
+        Returns:
+            A filtered pandas DataFrame.
+        """
+        if self._df.empty:
+            return self._df
+
+        filtered_df = self._df.copy()
+
+        if store_ids and 'store_id' in filtered_df.columns:
+            filtered_df = filtered_df[filtered_df['store_id'].isin(store_ids)]
+
+        if product_ids and 'product_id' in filtered_df.columns:
+            filtered_df = filtered_df[filtered_df['product_id'].isin(product_ids)]
+
+        return filtered_df
\ No newline at end of file
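For reference, a minimal usage sketch of the new data source (the file path and the store/product IDs are illustrative; any Parquet file with store_id and product_id columns would do):

    from server.data.parquet_source import ParquetDataSource

    # Hypothetical path; the constructor falls back to an empty DataFrame if it is missing.
    source = ParquetDataSource('data/sales.parquet')

    # Keyword filters follow the IDataSource.get_data contract.
    subset = source.get_data(store_ids=['S001'], product_ids=['P001'])
    print(subset.shape)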
diff --git a/server/repositories/model_repository.py b/server/repositories/model_repository.py
new file mode 100644
index 0000000..435e8be
--- /dev/null
+++ b/server/repositories/model_repository.py
@@ -0,0 +1,70 @@
+import sqlite3
+from typing import List, Dict, Any, Optional
+
+class ModelRepository:
+    """
+    Encapsulates all CRUD operations on the `model_registry` table in SQLite.
+    During early development every method is a no-op placeholder.
+    """
+    def __init__(self, db_path: str):
+        """
+        Initialize the repository.
+
+        Args:
+            db_path: Path to the SQLite database file.
+        """
+        self.db_path = db_path
+        self.conn = None  # The connection is established lazily
+
+    def _get_connection(self):
+        """Establish and return the database connection."""
+        if self.conn is None:
+            # The real implementation will connect to self.db_path:
+            # self.conn = sqlite3.connect(self.db_path)
+            # self.conn.row_factory = sqlite3.Row
+            pass
+        return self.conn
+
+    def add_model_version(self, model_data: Dict[str, Any]):
+        """
+        Add a new model version record to the database.
+
+        Args:
+            model_data: A dict of model metadata.
+        """
+        print(f"[Repository] (no-op) About to save model record: {model_data.get('model_uid')}")
+        # The real implementation will run a SQL INSERT statement
+        pass
+
+    def find_by_uid(self, model_uid: str) -> Optional[Dict[str, Any]]:
+        """
+        Look up a model by its unique ID.
+
+        Args:
+            model_uid: The model's unique ID.
+
+        Returns:
+            A dict of model data, or None if not found.
+        """
+        print(f"[Repository] (no-op) About to look up model by UID: {model_uid}")
+        return None
+
+    def find_all(self, **filters) -> List[Dict[str, Any]]:
+        """
+        Find all models matching the given filter conditions.
+
+        Args:
+            **filters: Filter conditions, e.g. training_mode='global'.
+
+        Returns:
+            A list of dicts, one per matching model record.
+        """
+        print(f"[Repository] (no-op) About to query models with filters: {filters}")
+        return []
+
+    def close(self):
+        """Close the database connection."""
+        if self.conn:
+            # self.conn.close()
+            self.conn = None
+            print("[Repository] (no-op) Database connection closed.")
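ModelRepository is still a no-op shell, and nothing in this diff actually creates the model_registry table. One plausible shape for it, inferred from the db_record fields assembled by ModelManagementService below (a sketch under that assumption, not the final DDL):

    import sqlite3

    # Columns mirror the db_record keys built in ModelManagementService;
    # 'scope' is a nested dict, so it would be stored as a JSON string.
    SCHEMA = """
    CREATE TABLE IF NOT EXISTS model_registry (
        model_uid          TEXT PRIMARY KEY,
        training_mode      TEXT NOT NULL,
        model_type         TEXT NOT NULL,
        version            INTEGER NOT NULL,
        path               TEXT NOT NULL,
        scope              TEXT,
        aggregation_method TEXT
    );
    """

    conn = sqlite3.connect('models.db')  # hypothetical db_path
    conn.execute(SCHEMA)
    conn.commit()
    conn.close()

The per-run metrics splatted in via payload.get('metrics', {}) would still need either dedicated columns or a JSON blob; the stub leaves that decision open.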
diff --git a/server/services/model_management_service.py b/server/services/model_management_service.py
new file mode 100644
index 0000000..0c43533
--- /dev/null
+++ b/server/services/model_management_service.py
@@ -0,0 +1,173 @@
+import os
+import json
+import shutil
+import uuid
+from typing import Dict, Any, List, Tuple
+
+from server.repositories.model_repository import ModelRepository
+from server.utils.hashing import generate_hash
+from server.services.version_manager import VersionManager
+
+class ModelManagementService:
+    """
+    Manages the full model lifecycle based on the training payload:
+    path construction, versioning, artifact persistence and database records.
+    """
+    def __init__(self, repository: ModelRepository, base_path='saved_models'):
+        self.repository = repository
+        self.base_path = base_path
+        self.version_manager = VersionManager(base_path)
+
+    def save_model_for_training(self, payload: Dict[str, Any], artifacts: Dict[str, str]):
+        """
+        Main dispatch function; delegates to a handler based on training_mode.
+        Before saving, it verifies that all required artifacts were provided.
+        """
+        # Step 1: enforce artifact completeness
+        REQUIRED_ARTIFACTS = {'model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png'}
+        provided_artifacts = set(artifacts.keys())
+
+        if not REQUIRED_ARTIFACTS.issubset(provided_artifacts):
+            missing = REQUIRED_ARTIFACTS - provided_artifacts
+            raise ValueError(f"Incomplete model artifacts; missing required files: {', '.join(missing)}")
+
+        # Step 2: resolve the target path and database record for this training mode
+        training_mode = payload.get('training_mode')
+        handler_map = {
+            'product': self._handle_product_training,
+            'store': self._handle_store_training,
+            'global': self._handle_global_training,
+        }
+        handler = handler_map.get(training_mode)
+        if not handler:
+            raise ValueError(f"Unknown training mode: {training_mode}")
+        final_path, db_record = handler(payload)
+
+        # Step 3: create the target directory and move the artifact files
+        os.makedirs(final_path, exist_ok=True)
+        for artifact_name, temp_path in artifacts.items():
+            # metadata.json is generated from db_record and handled separately
+            if artifact_name == 'metadata.json':
+                continue
+            shutil.move(temp_path, os.path.join(final_path, artifact_name))
+
+        # Step 4: write the final metadata file, merging the trainer-generated
+        # metadata with the service-layer metadata.
+        # Read the trainer-generated metadata from the temporary file.
+        trainer_metadata = {}
+        metadata_path = artifacts.get('metadata.json')
+        if metadata_path and os.path.exists(metadata_path):
+            with open(metadata_path, 'r', encoding='utf-8') as f:
+                # Guard against an empty or malformed file failing the whole flow
+                try:
+                    trainer_metadata = json.load(f)
+                except json.JSONDecodeError:
+                    print(f"Warning: could not parse metadata file {metadata_path}; it may be empty or malformed.")
+
+        # Merge the metadata
+        db_record.update(trainer_metadata)
+
+        with open(os.path.join(final_path, 'metadata.json'), 'w', encoding='utf-8') as f:
+            json.dump(db_record, f, indent=4, ensure_ascii=False)
+
+        # Step 5: add the final record to the database
+        self.repository.add_model_version(db_record)
+
+        return final_path, db_record
+
+    def _get_scope_path_and_definition(self, ids: List[str]) -> Tuple[str, Dict]:
+        """Derive the path fragment and scope definition from an ID list (hashing only when needed)."""
+        if len(ids) == 1:
+            return ids[0], {'type': 'single', 'id': ids[0]}
+
+        # Hash only when there is more than one ID
+        hash_val = generate_hash(ids)
+        return hash_val, {'type': 'hash', 'ids': sorted(ids)}
+
+    def _handle_product_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
+        product_id = payload.get('product_id')
+        if not product_id:
+            raise ValueError("'product_id' is required in product training mode")
+        model_type = payload['model_type']
+
+        model_base_path = os.path.join('product', product_id, model_type)
+        next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
+
+        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
+        model_uid = str(uuid.uuid4())
+
+        db_record = {
+            'model_uid': model_uid,
+            'training_mode': 'product',
+            'model_type': model_type,
+            'version': next_version,
+            'path': final_path,
+            'scope': {'product_id': product_id},
+            **payload.get('metrics', {})
+        }
+        return final_path, db_record
+
+    def _handle_store_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
+        store_id = payload.get('store_id')
+        if not store_id:
+            raise ValueError("'store_id' is required in store training mode")
+        model_type = payload['model_type']
+
+        scope_path = store_id
+        scope_definition = {'type': 'single', 'id': store_id}
+
+        model_base_path = os.path.join('store', scope_path, model_type)
+        next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
+
+        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
+        model_uid = str(uuid.uuid4())
+
+        db_record = {
+            'model_uid': model_uid,
+            'training_mode': 'store',
+            'model_type': model_type,
+            'version': next_version,
+            'path': final_path,
+            'scope': scope_definition,
+            **payload.get('metrics', {})
+        }
+        return final_path, db_record
+
+    def _handle_global_training(self, payload: Dict[str, Any]) -> Tuple[str, Dict]:
+        store_ids = payload.get('store_ids', [])
+        product_ids = payload.get('product_ids', [])
+        model_type = payload['model_type']
+        aggregation = payload.get('aggregation_method', 'all')
+
+        scope_path_parts = []
+        scope_definition = {}
+
+        if store_ids:
+            s_path, s_def = self._get_scope_path_and_definition(store_ids)
+            scope_path_parts.append(f"S_{s_path}")
+            scope_definition['stores'] = s_def
+
+        if product_ids:
+            p_path, p_def = self._get_scope_path_and_definition(product_ids)
+            scope_path_parts.append(f"P_{p_path}")
+            scope_definition['products'] = p_def
+
+        scope_path = "_".join(scope_path_parts) if scope_path_parts else "all"
+
+        model_base_path = os.path.join('global', scope_path, aggregation, model_type)
+        next_version = self.version_manager.get_next_version(model_base_path.replace(os.sep, '/'))
+
+        final_path = os.path.join(self.base_path, model_base_path, f'v{next_version}')
+        model_uid = str(uuid.uuid4())
+
+        db_record = {
+            'model_uid': model_uid,
+            'training_mode': 'global',
+            'model_type': model_type,
+            'version': next_version,
+            'path': final_path,
+            'scope': scope_definition,
+            'aggregation_method': aggregation,
+            **payload.get('metrics', {})
+        }
+        return final_path, db_record
\ No newline at end of file
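End to end, the service would be exercised roughly like this (paths, IDs, and the model type are illustrative, and VersionManager / generate_hash live in modules outside this diff, so their behavior is assumed):

    import os, tempfile
    from server.repositories.model_repository import ModelRepository
    from server.services.model_management_service import ModelManagementService

    repo = ModelRepository('models.db')  # hypothetical db_path
    service = ModelManagementService(repo)

    # The trainer hands over temp paths for all four required artifacts;
    # empty stand-in files are created here just to drive the flow.
    tmp = tempfile.mkdtemp()
    artifacts = {name: os.path.join(tmp, name)
                 for name in ('model.pth', 'checkpoint_best.pth', 'metadata.json', 'loss_curve.png')}
    for path in artifacts.values():
        open(path, 'wb').close()

    payload = {
        'training_mode': 'product',        # or 'store' / 'global'
        'product_id': 'P001',
        'model_type': 'transformer',       # illustrative model type
        'metrics': {'rmse': 0.12},         # splatted into the db record
    }
    final_path, record = service.save_model_for_training(payload, artifacts)
    # e.g. saved_models/product/P001/transformer/v1, containing the artifacts,
    # a merged metadata.json, and (once implemented) a new model_registry row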