import os import json import hashlib from threading import Lock from typing import List, Dict, Any, Optional class ModelPathManager: """ 根据定义的规则管理模型训练产物的保存路径。 此类旨在集中处理所有与文件系统交互的路径生成逻辑, 确保整个应用程序遵循统一的模型保存标准。 """ def __init__(self, base_dir: str = 'saved_models'): """ 初始化路径管理器。 Args: base_dir (str): 所有模型保存的根目录。 """ # 始终使用相对于项目根目录的相对路径 self.base_dir = base_dir self.versions_file = os.path.join(self.base_dir, 'versions.json') self.lock = Lock() # 确保根目录存在 os.makedirs(self.base_dir, exist_ok=True) def _hash_ids(self, ids: List[str]) -> str: """ 对ID列表进行排序和哈希,生成一个稳定的、简短的哈希值。 Args: ids (List[str]): 需要哈希的ID列表。 Returns: str: 代表该ID集合的10位短哈希字符串。 """ if not ids: return 'none' # 排序以确保对于相同集合的ID,即使顺序不同,结果也一样 sorted_ids = sorted([str(i) for i in ids]) id_string = ",".join(sorted_ids) # 使用SHA256生成哈希值并截取前10位 return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10] def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str: """ 根据训练模式和参数生成模型的唯一标识符 (identifier)。 这个标识符将作为版本文件中的key,并用于构建目录路径。 Args: training_mode (str): 训练模式 ('product', 'store', 'global')。 **kwargs: 从API请求中传递的参数字典。 Returns: str: 模型的唯一标识符。 Raises: ValueError: 如果缺少必要的参数。 """ if training_mode == 'product': product_id = kwargs.get('product_id') if not product_id: raise ValueError("按药品训练模式需要 'product_id'。") # 对于药品训练,数据范围由 store_id 定义 store_id = kwargs.get('store_id') scope = store_id if store_id is not None else 'all' return f"product_{product_id}_scope_{scope}" elif training_mode == 'store': store_id = kwargs.get('store_id') if not store_id: raise ValueError("按店铺训练模式需要 'store_id'。") product_scope = kwargs.get('product_scope', 'all') if product_scope == 'specific': product_ids = kwargs.get('product_ids') if not product_ids: raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) else: scope = 'all' return f"store_{store_id}_products_{scope}" elif training_mode == 'global': training_scope = kwargs.get('training_scope', 'all') if training_scope in ['all', 'all_stores_all_products']: scope_part = 'all' elif training_scope == 'selected_stores': store_ids = kwargs.get('store_ids') if not store_ids: raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) scope_part = f"stores_{scope_id}" elif training_scope == 'selected_products': product_ids = kwargs.get('product_ids') if not product_ids: raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_part = f"products_{scope_id}" elif training_scope == 'custom': store_ids = kwargs.get('store_ids') product_ids = kwargs.get('product_ids') if not store_ids or not product_ids: raise ValueError("全局训练选择 custom 范围时需要 'store_ids' 和 'product_ids'。") s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_part = f"custom_s_{s_id}_p_{p_id}" else: raise ValueError(f"未知的全局训练范围: {training_scope}") aggregation_method = kwargs.get('aggregation_method', 'sum') return f"global_{scope_part}_{aggregation_method}" else: raise ValueError(f"未知的训练模式: {training_mode}") def get_next_version(self, identifier: str) -> int: """ 获取指定标识符的下一个版本号。 此方法是线程安全的。 Args: identifier (str): 模型的唯一标识符。 Returns: int: 下一个可用的版本号 (从1开始)。 """ with self.lock: try: if os.path.exists(self.versions_file): with open(self.versions_file, 'r', encoding='utf-8') as f: versions_data = json.load(f) else: versions_data = {} # 如果标识符不存在,当前版本为0,下一个版本即为1 current_version = versions_data.get(identifier, 0) return current_version + 1 except (IOError, json.JSONDecodeError) as e: # 如果文件损坏或读取失败,从0开始 print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。") return 1 def save_version_info(self, identifier: str, new_version: int): """ 训练成功后,更新版本文件。 此方法是线程安全的。 Args: identifier (str): 模型的唯一标识符。 new_version (int): 要保存的新的版本号。 """ with self.lock: try: if os.path.exists(self.versions_file): with open(self.versions_file, 'r', encoding='utf-8') as f: versions_data = json.load(f) else: versions_data = {} versions_data[identifier] = new_version with open(self.versions_file, 'w', encoding='utf-8') as f: json.dump(versions_data, f, indent=4, ensure_ascii=False) except (IOError, json.JSONDecodeError) as e: print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}") # 在这种情况下,可以选择抛出异常或采取其他恢复措施 raise def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]: """ 主入口函数:为一次新的训练获取所有相关路径和版本信息。 此方法遵循扁平化文件存储规范,将逻辑路径编码到文件名中。 Args: training_mode (str): 训练模式 ('product', 'store', 'global')。 model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。 **kwargs: 从API请求中传递的参数字典。 Returns: Dict[str, Any]: 一个包含所有路径和关键信息的字典。 """ # 1. 生成不含模型类型和版本的核心标识符,并将其中的分隔符替换为下划线 # 例如:product/P001/all -> product_P001_all base_identifier = self._generate_identifier(training_mode, **kwargs) # 规范化处理,将 'scope' 'products' 等关键字替换为更简洁的形式 # 例如 product_P001_scope_all -> product_P001_all core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_') # 2. 构建用于版本控制的完整标识符 (不含版本号) # 例如: product_P001_all_mlstm version_control_identifier = f"{core_prefix}_{model_type}" # 3. 获取下一个版本号 next_version = self.get_next_version(version_control_identifier) version_str = f"v{next_version}" # 4. 构建最终的文件名前缀,包含版本号 # 例如: product_P001_all_mlstm_v2 filename_prefix = f"{version_control_identifier}_{version_str}" # 5. 确保 `saved_models` 和 `saved_models/checkpoints` 目录存在 checkpoints_base_dir = os.path.join(self.base_dir, 'checkpoints') os.makedirs(self.base_dir, exist_ok=True) os.makedirs(checkpoints_base_dir, exist_ok=True) # 6. 构建并返回包含所有扁平化路径和关键信息的字典 return { "identifier": version_control_identifier, # 用于版本控制的key "filename_prefix": filename_prefix, # 用于数据库和文件查找 "version": next_version, "base_dir": self.base_dir, "model_path": os.path.join(self.base_dir, f"{filename_prefix}_model.pth"), "metadata_path": os.path.join(self.base_dir, f"{filename_prefix}_metadata.json"), "loss_curve_path": os.path.join(self.base_dir, f"{filename_prefix}_loss_curve.png"), "checkpoint_dir": checkpoints_base_dir, # 指向公共的检查点目录 "best_checkpoint_path": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_best.pth"), # 为动态epoch检查点提供一个格式化模板 "epoch_checkpoint_template": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_epoch_{{N}}.pth") } def get_model_path_for_prediction(self, training_mode: str, model_type: str, version: int, **kwargs: Any) -> Optional[str]: """ 获取用于预测的已存在模型的完整路径 (遵循扁平化规范)。 Args: training_mode (str): 训练模式。 model_type (str): 模型类型。 version (int): 模型版本号。 **kwargs: 其他用于定位模型的参数。 Returns: Optional[str]: 模型的完整路径,如果不存在则返回None。 """ # 1. 生成不含模型类型和版本的核心标识符 base_identifier = self._generate_identifier(training_mode, **kwargs) core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_') # 2. 构建用于版本控制的标识符 version_control_identifier = f"{core_prefix}_{model_type}" # 3. 构建完整的文件名前缀 version_str = f"v{version}" filename_prefix = f"{version_control_identifier}_{version_str}" # 4. 构建模型文件的完整路径 model_path = os.path.join(self.base_dir, f"{filename_prefix}_model.pth") return model_path if os.path.exists(model_path) else None