ShopTRAINING/server/utils/file_save.py
xz2000 28bae35783 # 扁平化模型数据处理规范 (最终版)
**版本**: 4.0 (最终版)
**核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。

---

## 一、 文件保存规则

### 1.1. 核心原则

所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_metadata.json`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 二、 文件读取规则

1.  **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。
2.  **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。
3.  **定位文件**:
    -   要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。
    -   要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。

---

## 三、 数据库存储规则

数据库用于索引,应存储足以重构文件名前缀的关键元数据。

#### **`models` 表结构建议:**

| 字段名 | 类型 | 描述 | 示例 |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | 主键 | 1 |
| `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` |
| `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` |
| `version` | INTEGER | 版本号 | `2` |
| `status` | TEXT | 模型状态 | `completed`, `training`, `failed` |
| `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` |
| `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` |

#### **保存逻辑:**
-   训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。

---

## 四、 版本记录规则

版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。

-   **文件名**: `versions.json`
-   **位置**: `saved_models/versions.json`
-   **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。
    -   **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`)
    -   **Value**: `Integer`

#### **`versions.json` 示例:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **版本管理流程:**

1.  **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。
2.  **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。

调试完成药品预测和店铺预测
2025-07-21 16:39:52 +08:00

257 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import json
import hashlib
from threading import Lock
from typing import List, Dict, Any, Optional
class ModelPathManager:
"""
根据定义的规则管理模型训练产物的保存路径。
此类旨在集中处理所有与文件系统交互的路径生成逻辑,
确保整个应用程序遵循统一的模型保存标准。
"""
def __init__(self, base_dir: str = 'saved_models'):
"""
初始化路径管理器。
Args:
base_dir (str): 所有模型保存的根目录。
"""
# 始终使用相对于项目根目录的相对路径
self.base_dir = base_dir
self.versions_file = os.path.join(self.base_dir, 'versions.json')
self.lock = Lock()
# 确保根目录存在
os.makedirs(self.base_dir, exist_ok=True)
def _hash_ids(self, ids: List[str]) -> str:
"""
对ID列表进行排序和哈希生成一个稳定的、简短的哈希值。
Args:
ids (List[str]): 需要哈希的ID列表。
Returns:
str: 代表该ID集合的10位短哈希字符串。
"""
if not ids:
return 'none'
# 排序以确保对于相同集合的ID即使顺序不同结果也一样
sorted_ids = sorted([str(i) for i in ids])
id_string = ",".join(sorted_ids)
# 使用SHA256生成哈希值并截取前10位
return hashlib.sha256(id_string.encode('utf-8')).hexdigest()[:10]
def _generate_identifier(self, training_mode: str, **kwargs: Any) -> str:
"""
根据训练模式和参数生成模型的唯一标识符 (identifier)。
这个标识符将作为版本文件中的key并用于构建目录路径。
Args:
training_mode (str): 训练模式 ('product', 'store', 'global')。
**kwargs: 从API请求中传递的参数字典。
Returns:
str: 模型的唯一标识符。
Raises:
ValueError: 如果缺少必要的参数。
"""
if training_mode == 'product':
product_id = kwargs.get('product_id')
if not product_id:
raise ValueError("按药品训练模式需要 'product_id'")
# 对于药品训练,数据范围由 store_id 定义
store_id = kwargs.get('store_id')
scope = store_id if store_id is not None else 'all'
return f"product_{product_id}_scope_{scope}"
elif training_mode == 'store':
store_id = kwargs.get('store_id')
if not store_id:
raise ValueError("按店铺训练模式需要 'store_id'")
product_scope = kwargs.get('product_scope', 'all')
if product_scope == 'specific':
product_ids = kwargs.get('product_ids')
if not product_ids:
raise ValueError("店铺训练选择 specific 范围时需要 'product_ids'")
# 如果只有一个ID直接使用ID否则使用哈希
scope = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
else:
scope = 'all'
return f"store_{store_id}_products_{scope}"
elif training_mode == 'global':
training_scope = kwargs.get('training_scope', 'all')
if training_scope in ['all', 'all_stores_all_products']:
scope_part = 'all'
elif training_scope == 'selected_stores':
store_ids = kwargs.get('store_ids')
if not store_ids:
raise ValueError("全局训练选择 selected_stores 范围时需要 'store_ids'")
# 如果只有一个ID直接使用ID否则使用哈希
scope_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
scope_part = f"stores_{scope_id}"
elif training_scope == 'selected_products':
product_ids = kwargs.get('product_ids')
if not product_ids:
raise ValueError("全局训练选择 selected_products 范围时需要 'product_ids'")
# 如果只有一个ID直接使用ID否则使用哈希
scope_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
scope_part = f"products_{scope_id}"
elif training_scope == 'custom':
store_ids = kwargs.get('store_ids')
product_ids = kwargs.get('product_ids')
if not store_ids or not product_ids:
raise ValueError("全局训练选择 custom 范围时需要 'store_ids''product_ids'")
s_id = store_ids[0] if len(store_ids) == 1 else self._hash_ids(store_ids)
p_id = product_ids[0] if len(product_ids) == 1 else self._hash_ids(product_ids)
scope_part = f"custom_s_{s_id}_p_{p_id}"
else:
raise ValueError(f"未知的全局训练范围: {training_scope}")
aggregation_method = kwargs.get('aggregation_method', 'sum')
return f"global_{scope_part}_{aggregation_method}"
else:
raise ValueError(f"未知的训练模式: {training_mode}")
def get_next_version(self, identifier: str) -> int:
"""
获取指定标识符的下一个版本号。
此方法是线程安全的。
Args:
identifier (str): 模型的唯一标识符。
Returns:
int: 下一个可用的版本号 (从1开始)。
"""
with self.lock:
try:
if os.path.exists(self.versions_file):
with open(self.versions_file, 'r', encoding='utf-8') as f:
versions_data = json.load(f)
else:
versions_data = {}
# 如果标识符不存在当前版本为0下一个版本即为1
current_version = versions_data.get(identifier, 0)
return current_version + 1
except (IOError, json.JSONDecodeError) as e:
# 如果文件损坏或读取失败从0开始
print(f"警告: 读取版本文件 '{self.versions_file}' 失败: {e}。将从版本1开始。")
return 1
def save_version_info(self, identifier: str, new_version: int):
"""
训练成功后,更新版本文件。
此方法是线程安全的。
Args:
identifier (str): 模型的唯一标识符。
new_version (int): 要保存的新的版本号。
"""
with self.lock:
try:
if os.path.exists(self.versions_file):
with open(self.versions_file, 'r', encoding='utf-8') as f:
versions_data = json.load(f)
else:
versions_data = {}
versions_data[identifier] = new_version
with open(self.versions_file, 'w', encoding='utf-8') as f:
json.dump(versions_data, f, indent=4, ensure_ascii=False)
except (IOError, json.JSONDecodeError) as e:
print(f"错误: 保存版本信息到 '{self.versions_file}' 失败: {e}")
# 在这种情况下,可以选择抛出异常或采取其他恢复措施
raise
def get_model_paths(self, training_mode: str, model_type: str, **kwargs: Any) -> Dict[str, Any]:
"""
主入口函数:为一次新的训练获取所有相关路径和版本信息。
此方法遵循扁平化文件存储规范,将逻辑路径编码到文件名中。
Args:
training_mode (str): 训练模式 ('product', 'store', 'global')。
model_type (str): 模型类型 (e.g., 'mlstm', 'kan')。
**kwargs: 从API请求中传递的参数字典。
Returns:
Dict[str, Any]: 一个包含所有路径和关键信息的字典。
"""
# 1. 生成不含模型类型和版本的核心标识符,并将其中的分隔符替换为下划线
# 例如product/P001/all -> product_P001_all
base_identifier = self._generate_identifier(training_mode, **kwargs)
# 规范化处理,将 'scope' 'products' 等关键字替换为更简洁的形式
# 例如 product_P001_scope_all -> product_P001_all
core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')
# 2. 构建用于版本控制的完整标识符 (不含版本号)
# 例如: product_P001_all_mlstm
version_control_identifier = f"{core_prefix}_{model_type}"
# 3. 获取下一个版本号
next_version = self.get_next_version(version_control_identifier)
version_str = f"v{next_version}"
# 4. 构建最终的文件名前缀,包含版本号
# 例如: product_P001_all_mlstm_v2
filename_prefix = f"{version_control_identifier}_{version_str}"
# 5. 确保 `saved_models` 和 `saved_models/checkpoints` 目录存在
checkpoints_base_dir = os.path.join(self.base_dir, 'checkpoints')
os.makedirs(self.base_dir, exist_ok=True)
os.makedirs(checkpoints_base_dir, exist_ok=True)
# 6. 构建并返回包含所有扁平化路径和关键信息的字典
return {
"identifier": version_control_identifier, # 用于版本控制的key
"filename_prefix": filename_prefix, # 用于数据库和文件查找
"version": next_version,
"base_dir": self.base_dir,
"model_path": os.path.join(self.base_dir, f"{filename_prefix}_model.pth"),
"metadata_path": os.path.join(self.base_dir, f"{filename_prefix}_metadata.json"),
"loss_curve_path": os.path.join(self.base_dir, f"{filename_prefix}_loss_curve.png"),
"checkpoint_dir": checkpoints_base_dir, # 指向公共的检查点目录
"best_checkpoint_path": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_best.pth"),
# 为动态epoch检查点提供一个格式化模板
"epoch_checkpoint_template": os.path.join(checkpoints_base_dir, f"{filename_prefix}_checkpoint_epoch_{{N}}.pth")
}
def get_model_path_for_prediction(self, training_mode: str, model_type: str, version: int, **kwargs: Any) -> Optional[str]:
"""
获取用于预测的已存在模型的完整路径 (遵循扁平化规范)。
Args:
training_mode (str): 训练模式。
model_type (str): 模型类型。
version (int): 模型版本号。
**kwargs: 其他用于定位模型的参数。
Returns:
Optional[str]: 模型的完整路径如果不存在则返回None。
"""
# 1. 生成不含模型类型和版本的核心标识符
base_identifier = self._generate_identifier(training_mode, **kwargs)
core_prefix = base_identifier.replace('_scope_', '_').replace('_products_', '_')
# 2. 构建用于版本控制的标识符
version_control_identifier = f"{core_prefix}_{model_type}"
# 3. 构建完整的文件名前缀
version_str = f"v{version}"
filename_prefix = f"{version_control_identifier}_{version_str}"
# 4. 构建模型文件的完整路径
model_path = os.path.join(self.base_dir, f"{filename_prefix}_model.pth")
return model_path if os.path.exists(model_path) else None