265 lines
12 KiB
Python
265 lines
12 KiB
Python
|
import os
|
|||
|
import json
|
|||
|
import hashlib
|
|||
|
from threading import Lock
|
|||
|
from typing import List, Dict, Any, Optional
|
|||
|
|
|||
|
class ModelPathManager:
    """Manage the on-disk locations of model training artifacts.

    All path-generation logic lives here so that every caller follows the
    same directory layout. Per-identifier version numbers are tracked in a
    JSON file (``versions.json``) stored directly under ``base_dir``.
    """

    def __init__(self, base_dir: str = 'saved_models'):
        """Initialise the manager and make sure the root directory exists.

        Args:
            base_dir: Root directory under which every model is saved. It is
                stored exactly as given; a relative value is resolved by the
                OS against the current working directory.
        """
        self.base_dir = base_dir
        self.versions_file = os.path.join(self.base_dir, 'versions.json')
        # Serialises access to the versions file between threads of this
        # process that share this manager instance.
        self.lock = Lock()

        os.makedirs(self.base_dir, exist_ok=True)

    def _hash_ids(self, ids: List[str]) -> str:
        """Return a short, order-independent hash for a collection of IDs.

        Args:
            ids: IDs to hash. Each element is converted with ``str()`` first.

        Returns:
            ``'none'`` when ``ids`` is empty, otherwise the first 10 hex
            characters of the SHA-256 digest of the sorted, comma-joined IDs.
        """
        if not ids:
            return 'none'
        # Sorting makes the result independent of the order of the IDs.
        id_string = ",".join(sorted(str(i) for i in ids))
        return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10]

    def _scope_token(self, ids: List[Any]) -> str:
        """Return the token that represents ``ids`` in identifiers and paths.

        A single ID is used verbatim (converted to ``str`` so that it is a
        valid ``os.path.join`` component even when the caller passed an
        integer); any other number of IDs is collapsed by ``_hash_ids``.
        """
        if len(ids) == 1:
            return str(ids[0])
        return self._hash_ids(ids)

    def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str:
        """Build the identifier for a training configuration.

        ``get_model_paths`` appends the model type to this value and uses the
        result as the key in the versions file.

        Args:
            training_mode: One of ``'product'``, ``'store'`` or ``'global'``.
            **kwargs: Request parameters. Keys read here: ``product_id``,
                ``store_id``, ``product_scope``, ``product_ids``,
                ``store_ids``, ``training_scope``, ``aggregation_method``.

        Returns:
            The identifier string.

        Raises:
            ValueError: If a parameter required by the chosen mode/scope is
                missing or empty, or the mode/scope is unknown.
        """
        if training_mode == 'product':
            product_id = kwargs.get('product_id')
            if not product_id:
                raise ValueError("按药品训练模式需要 'product_id'。")
            # For product training the data scope is defined by store_id.
            store_id = kwargs.get('store_id')
            scope = store_id if store_id is not None else 'all'
            return f"product_{product_id}_scope_{scope}"

        if training_mode == 'store':
            store_id = kwargs.get('store_id')
            if not store_id:
                raise ValueError("按店铺训练模式需要 'store_id'。")

            if kwargs.get('product_scope', 'all') == 'specific':
                product_ids = kwargs.get('product_ids')
                if not product_ids:
                    raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'。")
                scope = self._scope_token(product_ids)
            else:
                scope = 'all'
            return f"store_{store_id}_products_{scope}"

        if training_mode == 'global':
            training_scope = kwargs.get('training_scope', 'all')

            if training_scope in ['all', 'all_stores_all_products']:
                scope_part = 'all'
            elif training_scope == 'selected_stores':
                store_ids = kwargs.get('store_ids')
                if not store_ids:
                    raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'。")
                scope_part = f"stores_{self._scope_token(store_ids)}"
            elif training_scope == 'selected_products':
                product_ids = kwargs.get('product_ids')
                if not product_ids:
                    raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'。")
                scope_part = f"products_{self._scope_token(product_ids)}"
            elif training_scope == 'custom':
                store_ids = kwargs.get('store_ids')
                product_ids = kwargs.get('product_ids')
                if not store_ids or not product_ids:
                    raise ValueError("全局训练选择 custom 范围时需要 'store_ids' 和 'product_ids'。")
                s_id = self._scope_token(store_ids)
                p_id = self._scope_token(product_ids)
                scope_part = f"custom_s_{s_id}_p_{p_id}"
            else:
                raise ValueError(f"未知的全局训练范围: {training_scope}")

            aggregation_method = kwargs.get('aggregation_method', 'sum')
            return f"global_{scope_part}_agg_{aggregation_method}"

        raise ValueError(f"未知的训练模式: {training_mode}")

    def _load_versions(self) -> Dict[str, Any]:
        """Read the versions file; return ``{}`` when it does not exist.

        Callers must hold ``self.lock``. I/O and JSON decoding errors are
        propagated so each caller can apply its own recovery policy.
        """
        if not os.path.exists(self.versions_file):
            return {}
        with open(self.versions_file, 'r', encoding='utf-8') as f:
            return json.load(f)

    def get_next_version(self, identifier: str) -> int:
        """Return the next version number for ``identifier``.

        The versions file is read while holding ``self.lock``. Nothing is
        written; the number only becomes "used" once ``save_version_info``
        records it.

        Args:
            identifier: Key to look up in the versions file.

        Returns:
            The stored version plus one, or ``1`` when the identifier is
            unknown or the versions file cannot be read/parsed (a warning is
            printed in the latter case).
        """
        with self.lock:
            try:
                versions_data = self._load_versions()
            except (IOError, json.JSONDecodeError) as e:
                # Unreadable or corrupt file: fall back to the first version.
                print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。")
                return 1
            # An unknown identifier counts as version 0, so the next one is 1.
            return versions_data.get(identifier, 0) + 1

    def save_version_info(self, identifier: str, new_version: int) -> None:
        """Record ``new_version`` for ``identifier`` in the versions file.

        The read-modify-write cycle runs while holding ``self.lock``. The new
        content is written to a sibling temporary file and then moved over
        the real file with ``os.replace``, so an interrupted write cannot
        leave a truncated ``versions.json`` behind.

        Args:
            identifier: Key to update in the versions file.
            new_version: Version number to store.

        Raises:
            IOError: If the versions file cannot be read or written.
            json.JSONDecodeError: If the existing versions file is not valid
                JSON.
        """
        with self.lock:
            try:
                versions_data = self._load_versions()
                versions_data[identifier] = new_version

                tmp_path = f"{self.versions_file}.tmp"
                with open(tmp_path, 'w', encoding='utf-8') as f:
                    json.dump(versions_data, f, indent=4, ensure_ascii=False)
                os.replace(tmp_path, self.versions_file)
            except (IOError, json.JSONDecodeError) as e:
                print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}")
                # Let the caller decide how to recover.
                raise

    def _path_parts(self, training_mode: str, model_type: str,
                    version_str: str, **kwargs: Any) -> List[str]:
        """Return the directory components (below ``base_dir``) for a run.

        Layout:
            product: ``product/<product_id>_<scope>/<model_type>/<version>``
            store:   ``store/<store_id>_<scope>/<model_type>/<version>``
            global:  ``global/<scope...>/<aggregation>/<model_type>/<version>``

        Raises:
            ValueError: If the training mode or global scope is unknown.
        """
        if training_mode == 'product':
            store_id = kwargs.get('store_id')
            scope = store_id if store_id is not None else 'all'
            scope_folder = f"{kwargs.get('product_id')}_{scope}"
            return [training_mode, scope_folder, model_type, version_str]

        if training_mode == 'store':
            if kwargs.get('product_scope', 'all') == 'specific':
                scope = self._scope_token(kwargs.get('product_ids', []))
            else:
                scope = 'all'
            scope_folder = f"{kwargs.get('store_id')}_{scope}"
            return [training_mode, scope_folder, model_type, version_str]

        if training_mode == 'global':
            training_scope = kwargs.get('training_scope', 'all')
            parts = [training_mode]
            if training_scope in ['all', 'all_stores_all_products']:
                parts.append('all')
            elif training_scope == 'selected_stores':
                parts.extend(['stores', self._scope_token(kwargs.get('store_ids', []))])
            elif training_scope == 'selected_products':
                parts.extend(['products', self._scope_token(kwargs.get('product_ids', []))])
            elif training_scope == 'custom':
                parts.extend([
                    'custom',
                    self._scope_token(kwargs.get('store_ids', [])),
                    self._scope_token(kwargs.get('product_ids', [])),
                ])
            else:
                raise ValueError(f"未知的全局训练范围: {training_scope}")
            parts.extend([kwargs.get('aggregation_method', 'sum'), model_type, version_str])
            return parts

        raise ValueError(f"未知的训练模式: {training_mode}")

    def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]:
        """Return every path and version detail needed for a new training run.

        Builds the identifier, looks up the next version number, creates the
        version directory and its ``checkpoints`` sub-directory, and returns
        the resulting locations. The versions file is not updated here; call
        ``save_version_info`` once training has succeeded.

        Args:
            training_mode: One of ``'product'``, ``'store'`` or ``'global'``.
            model_type: Model type name (e.g. ``'mlstm'``, ``'kan'``); it is
                appended to the identifier and used as a directory name.
            **kwargs: Request parameters, see ``_generate_identifier``.

        Returns:
            A dict with the keys ``identifier``, ``version``, ``base_dir``,
            ``version_dir``, ``model_path``, ``metadata_path``,
            ``loss_curve_path``, ``checkpoint_dir`` and
            ``best_checkpoint_path``.

        Raises:
            ValueError: If required parameters are missing, the mode/scope is
                unknown, or the resulting directory would fall outside
                ``base_dir``.
        """
        # 1. Identifier, suffixed with the model type so that each model type
        #    is versioned independently. This also validates the parameters.
        base_identifier = self._generate_identifier(training_mode, **kwargs)
        full_identifier = f"{base_identifier}_{model_type}"

        # 2. Next version number.
        next_version = self.get_next_version(full_identifier)
        version_str = f"v{next_version}"

        # 3. Directory components for this mode.
        path_parts = self._path_parts(training_mode, model_type, version_str, **kwargs)
        version_dir = os.path.join(self.base_dir, *path_parts)

        # IDs originate from request parameters: refuse components such as
        # '..' or an absolute path that would move the directory out of
        # base_dir, before anything is created on disk.
        base_prefix = os.path.join(os.path.abspath(self.base_dir), '')
        if not os.path.abspath(version_dir).startswith(base_prefix):
            raise ValueError(f"生成的模型路径超出了根目录 '{self.base_dir}': {version_dir}")

        # 4. Create the version directory and its checkpoint sub-directory.
        os.makedirs(version_dir, exist_ok=True)
        checkpoint_dir = os.path.join(version_dir, 'checkpoints')
        os.makedirs(checkpoint_dir, exist_ok=True)

        # 5. Everything the caller needs, in one dict.
        return {
            "identifier": full_identifier,
            "version": next_version,
            "base_dir": self.base_dir,
            "version_dir": version_dir,
            "model_path": os.path.join(version_dir, "model.pth"),
            "metadata_path": os.path.join(version_dir, "metadata.json"),
            "loss_curve_path": os.path.join(version_dir, "loss_curve.png"),
            "checkpoint_dir": checkpoint_dir,
            "best_checkpoint_path": os.path.join(checkpoint_dir, "checkpoint_best.pth"),
        }
|