import os import json import hashlib from threading import Lock from typing import List, Dict, Any, Optional class ModelPathManager: """ 根据定义的规则管理模型训练产物的保存路径。 此类旨在集中处理所有与文件系统交互的路径生成逻辑, 确保整个应用程序遵循统一的模型保存标准。 """ def __init__(self, base_dir: str = 'saved_models'): """ 初始化路径管理器。 Args: base_dir (str): 所有模型保存的根目录。 """ # 始终使用相对于项目根目录的相对路径 self.base_dir = base_dir self.versions_file = os.path.join(self.base_dir, 'versions.json') self.lock = Lock() # 确保根目录存在 os.makedirs(self.base_dir, exist_ok=True) def _hash_ids(self, ids: List[str]) -> str: """ 对ID列表进行排序和哈希,生成一个稳定的、简短的哈希值。 Args: ids (List[str]): 需要哈希的ID列表。 Returns: str: 代表该ID集合的10位短哈希字符串。 """ if not ids: return 'none' # 排序以确保对于相同集合的ID,即使顺序不同,结果也一样 sorted_ids = sorted([str(i) for i in ids]) id_string = ",".join(sorted_ids) # 使用SHA256生成哈希值并截取前10位 return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10] def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str: """ 根据训练模式和参数生成模型的唯一标识符 (identifier)。 这个标识符将作为版本文件中的key,并用于构建目录路径。 Args: training_mode (str): 训练模式 ('product', 'store', 'global')。 **kwargs: 从API请求中传递的参数字典。 Returns: str: 模型的唯一标识符。 Raises: ValueError: 如果缺少必要的参数。 """ if training_mode == 'product': product_id = kwargs.get('product_id') if not product_id: raise ValueError("按药品训练模式需要 'product_id'。") # 对于药品训练,数据范围由 store_id 定义 store_id = kwargs.get('store_id') scope = store_id if store_id is not None else 'all' return f"product_{product_id}_scope_{scope}" elif training_mode == 'store': store_id = kwargs.get('store_id') if not store_id: raise ValueError("按店铺训练模式需要 'store_id'。") product_scope = kwargs.get('product_scope', 'all') if product_scope == 'specific': product_ids = kwargs.get('product_ids') if not product_ids: raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) else: scope = 'all' return f"store_{store_id}_products_{scope}" elif training_mode == 'global': training_scope = kwargs.get('training_scope', 'all') if training_scope in ['all', 'all_stores_all_products']: scope_part = 'all' elif training_scope == 'selected_stores': store_ids = kwargs.get('store_ids') if not store_ids: raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) scope_part = f"stores_{scope_id}" elif training_scope == 'selected_products': product_ids = kwargs.get('product_ids') if not product_ids: raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'。") # 如果只有一个ID,直接使用ID;否则使用哈希 scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_part = f"products_{scope_id}" elif training_scope == 'custom': store_ids = kwargs.get('store_ids') product_ids = kwargs.get('product_ids') if not store_ids or not product_ids: raise ValueError("全局训练选择 custom 范围时需要 'store_ids' 和 'product_ids'。") s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_part = f"custom_s_{s_id}_p_{p_id}" else: raise ValueError(f"未知的全局训练范围: {training_scope}") aggregation_method = kwargs.get('aggregation_method', 'sum') return f"global_{scope_part}_agg_{aggregation_method}" else: raise ValueError(f"未知的训练模式: {training_mode}") def get_next_version(self, identifier: str) -> int: """ 获取指定标识符的下一个版本号。 此方法是线程安全的。 Args: identifier (str): 模型的唯一标识符。 Returns: int: 下一个可用的版本号 (从1开始)。 """ with self.lock: try: if os.path.exists(self.versions_file): with open(self.versions_file, 'r', encoding='utf-8') as f: versions_data = json.load(f) else: versions_data = {} # 如果标识符不存在,当前版本为0,下一个版本即为1 current_version = versions_data.get(identifier, 0) return current_version + 1 except (IOError, json.JSONDecodeError) as e: # 如果文件损坏或读取失败,从0开始 print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。") return 1 def save_version_info(self, identifier: str, new_version: int): """ 训练成功后,更新版本文件。 此方法是线程安全的。 Args: identifier (str): 模型的唯一标识符。 new_version (int): 要保存的新的版本号。 """ with self.lock: try: if os.path.exists(self.versions_file): with open(self.versions_file, 'r', encoding='utf-8') as f: versions_data = json.load(f) else: versions_data = {} versions_data[identifier] = new_version with open(self.versions_file, 'w', encoding='utf-8') as f: json.dump(versions_data, f, indent=4, ensure_ascii=False) except (IOError, json.JSONDecodeError) as e: print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}") # 在这种情况下,可以选择抛出异常或采取其他恢复措施 raise def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]: """ 主入口函数:为一次新的训练获取所有相关路径和版本信息。 此方法会生成唯一的模型标识符,获取新版本号,并构建所有产物的完整路径。 Args: training_mode (str): 训练模式 ('product', 'store', 'global')。 model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。 **kwargs: 从API请求中传递的参数字典。 Returns: Dict[str, Any]: 一个包含所有路径和关键信息的字典。 """ # 1. 生成唯一标识符,并加上模型类型,确保不同模型类型有不同的版本控制 base_identifier = self._generate_identifier(training_mode, **kwargs) full_identifier = f"{base_identifier}_{model_type}" # 2. 获取下一个版本号 next_version = self.get_next_version(full_identifier) version_str = f"v{next_version}" # 3. 根据规则构建基础路径 if training_mode == 'product': product_id = kwargs.get('product_id') store_id = kwargs.get('store_id') scope = store_id if store_id is not None else 'all' scope_folder = f"{product_id}_{scope}" path_parts = [training_mode, scope_folder, model_type, version_str] elif training_mode == 'store': store_id = kwargs.get('store_id') product_scope = kwargs.get('product_scope', 'all') if product_scope == 'specific': product_ids = kwargs.get('product_ids', []) # 如果只有一个ID,直接使用ID;否则使用哈希 scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) else: scope = 'all' scope_folder = f"{store_id}_{scope}" path_parts = [training_mode, scope_folder, model_type, version_str] elif training_mode == 'global': aggregation_method = kwargs.get('aggregation_method', 'sum') training_scope = kwargs.get('training_scope', 'all') scope_parts = [training_mode] if training_scope in ['all', 'all_stores_all_products']: scope_parts.append('all') elif training_scope == 'selected_stores': store_ids = kwargs.get('store_ids', []) scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) scope_parts.extend(['stores', scope_id]) elif training_scope == 'selected_products': product_ids = kwargs.get('product_ids', []) scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_parts.extend(['products', scope_id]) elif training_scope == 'custom': store_ids = kwargs.get('store_ids', []) product_ids = kwargs.get('product_ids', []) s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids) p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids) scope_parts.extend(['custom', s_id, p_id]) scope_parts.extend([aggregation_method, model_type, version_str]) path_parts = scope_parts else: raise ValueError(f"未知的训练模式: {training_mode}") # 4. 创建版本目录 version_dir = os.path.join(self.base_dir, *path_parts) os.makedirs(version_dir, exist_ok=True) # 创建检查点子目录 checkpoint_dir = os.path.join(version_dir, 'checkpoints') os.makedirs(checkpoint_dir, exist_ok=True) # 5. 构建并返回包含所有信息的字典 return { "identifier": full_identifier, "version": next_version, "base_dir": self.base_dir, "version_dir": version_dir, "model_path": os.path.join(version_dir, "model.pth"), "metadata_path": os.path.join(version_dir, "metadata.json"), "loss_curve_path": os.path.join(version_dir, "loss_curve.png"), "checkpoint_dir": checkpoint_dir, "best_checkpoint_path": os.path.join(checkpoint_dir, "checkpoint_best.pth") }