**版本**: 4.0 (最终版) **核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。 --- ## 一、 文件保存规则 ### 1.1. 核心原则 所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。 ### 1.2. 文件存储位置 - **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。 - **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。 ### 1.3. 文件名生成规则 1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。 - *示例*: `product/P001_all/mlstm/v2` 2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。 - *示例*: `product_P001_all_mlstm_v2` 3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。 - `_model.pth` - `_metadata.json` - `_loss_curve.png` - `_checkpoint_best.pth` - `_checkpoint_epoch_{N}.pth` #### **完整示例:** - **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth` - **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json` - **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth` - **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth` --- ## 二、 文件读取规则 1. **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。 2. **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。 3. **定位文件**: - 要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。 - 要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。 --- ## 三、 数据库存储规则 数据库用于索引,应存储足以重构文件名前缀的关键元数据。 #### **`models` 表结构建议:** | 字段名 | 类型 | 描述 | 示例 | | :--- | :--- | :--- | :--- | | `id` | INTEGER | 主键 | 1 | | `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` | | `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` | | `version` | INTEGER | 版本号 | `2` | | `status` | TEXT | 模型状态 | `completed`, `training`, `failed` | | `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` | | `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` | #### **保存逻辑:** - 训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。 --- ## 四、 版本记录规则 版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。 - **文件名**: `versions.json` - **位置**: `saved_models/versions.json` - **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。 - **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`) - **Value**: `Integer` #### **`versions.json` 示例:** ```json { "product_P001_all_mlstm": 2, "store_S001_P002_transformer": 1 } ``` #### **版本管理流程:** 1. **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。 2. **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。 调试完成药品预测和店铺预测
257 lines
12 KiB
Python
257 lines
12 KiB
Python
import os
|
||
import json
|
||
import hashlib
|
||
from threading import Lock
|
||
from typing import List, Dict, Any, Optional
|
||
|
||
class ModelPathManager:
|
||
"""
|
||
根据定义的规则管理模型训练产物的保存路径。
|
||
此类旨在集中处理所有与文件系统交互的路径生成逻辑,
|
||
确保整个应用程序遵循统一的模型保存标准。
|
||
"""
|
||
def __init__(self, base_dir: str = 'saved_models'):
|
||
"""
|
||
初始化路径管理器。
|
||
|
||
Args:
|
||
base_dir (str): 所有模型保存的根目录。
|
||
"""
|
||
# 始终使用相对于项目根目录的相对路径
|
||
self.base_dir = base_dir
|
||
self.versions_file = os.path.join(self.base_dir, 'versions.json')
|
||
self.lock = Lock()
|
||
|
||
# 确保根目录存在
|
||
os.makedirs(self.base_dir, exist_ok=True)
|
||
|
||
def _hash_ids(self, ids: List[str]) -> str:
|
||
"""
|
||
对ID列表进行排序和哈希,生成一个稳定的、简短的哈希值。
|
||
|
||
Args:
|
||
ids (List[str]): 需要哈希的ID列表。
|
||
|
||
Returns:
|
||
str: 代表该ID集合的10位短哈希字符串。
|
||
"""
|
||
if not ids:
|
||
return 'none'
|
||
# 排序以确保对于相同集合的ID,即使顺序不同,结果也一样
|
||
sorted_ids = sorted([str(i) for i in ids])
|
||
id_string = ",".join(sorted_ids)
|
||
|
||
# 使用SHA256生成哈希值并截取前10位
|
||
return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10]
|
||
|
||
def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str:
|
||
"""
|
||
根据训练模式和参数生成模型的唯一标识符 (identifier)。
|
||
这个标识符将作为版本文件中的key,并用于构建目录路径。
|
||
|
||
Args:
|
||
training_mode (str): 训练模式 ('product', 'store', 'global')。
|
||
**kwargs: 从API请求中传递的参数字典。
|
||
|
||
Returns:
|
||
str: 模型的唯一标识符。
|
||
|
||
Raises:
|
||
ValueError: 如果缺少必要的参数。
|
||
"""
|
||
if training_mode == 'product':
|
||
product_id = kwargs.get('product_id')
|
||
if not product_id:
|
||
raise ValueError("按药品训练模式需要 'product_id'。")
|
||
# 对于药品训练,数据范围由 store_id 定义
|
||
store_id = kwargs.get('store_id')
|
||
scope = store_id if store_id is not None else 'all'
|
||
return f"product_{product_id}_scope_{scope}"
|
||
|
||
elif training_mode == 'store':
|
||
store_id = kwargs.get('store_id')
|
||
if not store_id:
|
||
raise ValueError("按店铺训练模式需要 'store_id'。")
|
||
|
||
product_scope = kwargs.get('product_scope', 'all')
|
||
if product_scope == 'specific':
|
||
product_ids = kwargs.get('product_ids')
|
||
if not product_ids:
|
||
raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'。")
|
||
# 如果只有一个ID,直接使用ID;否则使用哈希
|
||
scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
|
||
else:
|
||
scope = 'all'
|
||
return f"store_{store_id}_products_{scope}"
|
||
|
||
elif training_mode == 'global':
|
||
training_scope = kwargs.get('training_scope', 'all')
|
||
|
||
if training_scope in ['all', 'all_stores_all_products']:
|
||
scope_part = 'all'
|
||
elif training_scope == 'selected_stores':
|
||
store_ids = kwargs.get('store_ids')
|
||
if not store_ids:
|
||
raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'。")
|
||
# 如果只有一个ID,直接使用ID;否则使用哈希
|
||
scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
|
||
scope_part = f"stores_{scope_id}"
|
||
elif training_scope == 'selected_products':
|
||
product_ids = kwargs.get('product_ids')
|
||
if not product_ids:
|
||
raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'。")
|
||
# 如果只有一个ID,直接使用ID;否则使用哈希
|
||
scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
|
||
scope_part = f"products_{scope_id}"
|
||
elif training_scope == 'custom':
|
||
store_ids = kwargs.get('store_ids')
|
||
product_ids = kwargs.get('product_ids')
|
||
if not store_ids or not product_ids:
|
||
raise ValueError("全局训练选择 custom 范围时需要 'store_ids' 和 'product_ids'。")
|
||
s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
|
||
p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
|
||
scope_part = f"custom_s_{s_id}_p_{p_id}"
|
||
else:
|
||
raise ValueError(f"未知的全局训练范围: {training_scope}")
|
||
|
||
aggregation_method = kwargs.get('aggregation_method', 'sum')
|
||
return f"global_{scope_part}_{aggregation_method}"
|
||
|
||
else:
|
||
raise ValueError(f"未知的训练模式: {training_mode}")
|
||
|
||
def get_next_version(self, identifier: str) -> int:
|
||
"""
|
||
获取指定标识符的下一个版本号。
|
||
此方法是线程安全的。
|
||
|
||
Args:
|
||
identifier (str): 模型的唯一标识符。
|
||
|
||
Returns:
|
||
int: 下一个可用的版本号 (从1开始)。
|
||
"""
|
||
with self.lock:
|
||
try:
|
||
if os.path.exists(self.versions_file):
|
||
with open(self.versions_file, 'r', encoding='utf-8') as f:
|
||
versions_data = json.load(f)
|
||
else:
|
||
versions_data = {}
|
||
|
||
# 如果标识符不存在,当前版本为0,下一个版本即为1
|
||
current_version = versions_data.get(identifier, 0)
|
||
return current_version + 1
|
||
except (IOError, json.JSONDecodeError) as e:
|
||
# 如果文件损坏或读取失败,从0开始
|
||
print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。")
|
||
return 1
|
||
|
||
def save_version_info(self, identifier: str, new_version: int):
|
||
"""
|
||
训练成功后,更新版本文件。
|
||
此方法是线程安全的。
|
||
|
||
Args:
|
||
identifier (str): 模型的唯一标识符。
|
||
new_version (int): 要保存的新的版本号。
|
||
"""
|
||
with self.lock:
|
||
try:
|
||
if os.path.exists(self.versions_file):
|
||
with open(self.versions_file, 'r', encoding='utf-8') as f:
|
||
versions_data = json.load(f)
|
||
else:
|
||
versions_data = {}
|
||
|
||
versions_data[identifier] = new_version
|
||
|
||
with open(self.versions_file, 'w', encoding='utf-8') as f:
|
||
json.dump(versions_data, f, indent=4, ensure_ascii=False)
|
||
except (IOError, json.JSONDecodeError) as e:
|
||
print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}")
|
||
# 在这种情况下,可以选择抛出异常或采取其他恢复措施
|
||
raise
|
||
|
||
def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]:
|
||
"""
|
||
主入口函数:为一次新的训练获取所有相关路径和版本信息。
|
||
此方法遵循扁平化文件存储规范,将逻辑路径编码到文件名中。
|
||
|
||
Args:
|
||
training_mode (str): 训练模式 ('product', 'store', 'global')。
|
||
model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。
|
||
**kwargs: 从API请求中传递的参数字典。
|
||
|
||
Returns:
|
||
Dict[str, Any]: 一个包含所有路径和关键信息的字典。
|
||
"""
|
||
# 1. 生成不含模型类型和版本的核心标识符,并将其中的分隔符替换为下划线
|
||
# 例如:product/P001/all -> product_P001_all
|
||
base_identifier = self._generate_identifier(training_mode, **kwargs)
|
||
|
||
# 规范化处理,将 'scope' 'products' 等关键字替换为更简洁的形式
|
||
# 例如 product_P001_scope_all -> product_P001_all
|
||
core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')
|
||
|
||
# 2. 构建用于版本控制的完整标识符 (不含版本号)
|
||
# 例如: product_P001_all_mlstm
|
||
version_control_identifier = f"{core_prefix}_{model_type}"
|
||
|
||
# 3. 获取下一个版本号
|
||
next_version = self.get_next_version(version_control_identifier)
|
||
version_str = f"v{next_version}"
|
||
|
||
# 4. 构建最终的文件名前缀,包含版本号
|
||
# 例如: product_P001_all_mlstm_v2
|
||
filename_prefix = f"{version_control_identifier}_{version_str}"
|
||
|
||
# 5. 确保 `saved_models` 和 `saved_models/checkpoints` 目录存在
|
||
checkpoints_base_dir = os.path.join(self.base_dir, 'checkpoints')
|
||
os.makedirs(self.base_dir, exist_ok=True)
|
||
os.makedirs(checkpoints_base_dir, exist_ok=True)
|
||
|
||
# 6. 构建并返回包含所有扁平化路径和关键信息的字典
|
||
return {
|
||
"identifier": version_control_identifier, # 用于版本控制的key
|
||
"filename_prefix": filename_prefix, # 用于数据库和文件查找
|
||
"version": next_version,
|
||
"base_dir": self.base_dir,
|
||
"model_path": os.path.join(self.base_dir, f"{filename_prefix}_model.pth"),
|
||
"metadata_path": os.path.join(self.base_dir, f"{filename_prefix}_metadata.json"),
|
||
"loss_curve_path": os.path.join(self.base_dir, f"{filename_prefix}_loss_curve.png"),
|
||
"checkpoint_dir": checkpoints_base_dir, # 指向公共的检查点目录
|
||
"best_checkpoint_path": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_best.pth"),
|
||
# 为动态epoch检查点提供一个格式化模板
|
||
"epoch_checkpoint_template": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_epoch_{{N}}.pth")
|
||
}
|
||
|
||
def get_model_path_for_prediction(self, training_mode: str, model_type: str, version: int, **kwargs: Any) -> Optional[str]:
|
||
"""
|
||
获取用于预测的已存在模型的完整路径 (遵循扁平化规范)。
|
||
|
||
Args:
|
||
training_mode (str): 训练模式。
|
||
model_type (str): 模型类型。
|
||
version (int): 模型版本号。
|
||
**kwargs: 其他用于定位模型的参数。
|
||
|
||
Returns:
|
||
Optional[str]: 模型的完整路径,如果不存在则返回None。
|
||
"""
|
||
# 1. 生成不含模型类型和版本的核心标识符
|
||
base_identifier = self._generate_identifier(training_mode, **kwargs)
|
||
core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')
|
||
|
||
# 2. 构建用于版本控制的标识符
|
||
version_control_identifier = f"{core_prefix}_{model_type}"
|
||
|
||
# 3. 构建完整的文件名前缀
|
||
version_str = f"v{version}"
|
||
filename_prefix = f"{version_control_identifier}_{version_str}"
|
||
|
||
# 4. 构建模型文件的完整路径
|
||
model_path = os.path.join(self.base_dir, f"{filename_prefix}_model.pth")
|
||
|
||
return model_path if os.path.exists(model_path) else None
|