ShopTRAINING/server/utils/model_manager.py
xz2000 28bae35783 # 扁平化模型数据处理规范 (最终版)
**版本**: 4.0 (最终版)
**核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。

---

## 一、 文件保存规则

### 1.1. 核心原则

所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。

### 1.2. 文件存储位置

-   **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。
-   **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。

### 1.3. 文件名生成规则

1.  **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。
    -   *示例*: `product/P001_all/mlstm/v2`

2.  **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。
    -   *示例*: `product_P001_all_mlstm_v2`

3.  **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。
    -   `_model.pth`
    -   `_metadata.json`
    -   `_loss_curve.png`
    -   `_checkpoint_best.pth`
    -   `_checkpoint_epoch_{N}.pth`

#### **完整示例:**

-   **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth`
-   **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json`
-   **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth`
-   **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`

---

## 二、 文件读取规则

1.  **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。
2.  **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。
3.  **定位文件**:
    -   要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。
    -   要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。

---

## 三、 数据库存储规则

数据库用于索引,应存储足以重构文件名前缀的关键元数据。

#### **`models` 表结构建议:**

| 字段名 | 类型 | 描述 | 示例 |
| :--- | :--- | :--- | :--- |
| `id` | INTEGER | 主键 | 1 |
| `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` |
| `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` |
| `version` | INTEGER | 版本号 | `2` |
| `status` | TEXT | 模型状态 | `completed`, `training`, `failed` |
| `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` |
| `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` |

#### **保存逻辑:**
-   训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。

---

## 四、 版本记录规则

版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。

-   **文件名**: `versions.json`
-   **位置**: `saved_models/versions.json`
-   **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。
    -   **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`)
    -   **Value**: `Integer`

#### **`versions.json` 示例:**
```json
{
  "product_P001_all_mlstm": 2,
  "store_S001_P002_transformer": 1
}
```

#### **版本管理流程:**

1.  **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。
2.  **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。

调试完成药品预测和店铺预测
2025-07-21 16:39:52 +08:00

401 lines
15 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
统一模型管理工具
处理模型文件的统一命名、存储和检索
"""
import os
import json
import torch
import glob
from datetime import datetime
from typing import List, Dict, Optional, Tuple
from core.config import DEFAULT_MODEL_DIR
class ModelManager:
"""统一模型管理器"""
def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
self.model_dir = model_dir
self.ensure_model_dir()
def ensure_model_dir(self):
"""确保模型目录存在"""
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
def generate_model_filename(self,
product_id: str,
model_type: str,
version: str,
store_id: Optional[str] = None,
training_mode: str = 'product',
aggregation_method: Optional[str] = None) -> str:
"""
生成统一的模型文件名
格式规范:
- 产品模式: {model_type}_product_{product_id}_{version}.pth
- 店铺模式: {model_type}_store_{store_id}_{product_id}_{version}.pth
- 全局模式: {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
"""
if training_mode == 'store' and store_id:
return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
elif training_mode == 'global' and aggregation_method:
return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
else:
# 默认产品模式
return f"{model_type}_product_{product_id}_{version}.pth"
def save_model(self,
model_data: dict,
product_id: str,
model_type: str,
version: str,
store_id: Optional[str] = None,
training_mode: str = 'product',
aggregation_method: Optional[str] = None,
product_name: Optional[str] = None) -> str:
"""
保存模型到统一位置
参数:
model_data: 包含模型状态和配置的字典
product_id: 产品ID
model_type: 模型类型
version: 版本号
store_id: 店铺ID (可选)
training_mode: 训练模式
aggregation_method: 聚合方法 (可选)
product_name: 产品名称 (可选)
返回:
模型文件路径
"""
filename = self.generate_model_filename(
product_id, model_type, version, store_id, training_mode, aggregation_method
)
# 统一保存到根目录,避免复杂的子目录结构
model_path = os.path.join(self.model_dir, filename)
# 增强模型数据,添加管理信息
enhanced_model_data = model_data.copy()
enhanced_model_data.update({
'model_manager_info': {
'product_id': product_id,
'product_name': product_name or product_id,
'model_type': model_type,
'version': version,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method,
'created_at': datetime.now().isoformat(),
'filename': filename
}
})
# 保存模型
torch.save(enhanced_model_data, model_path)
print(f"模型已保存: {model_path}")
return model_path
def list_models(self,
product_id: Optional[str] = None,
model_type: Optional[str] = None,
store_id: Optional[str] = None,
training_mode: Optional[str] = None,
page: Optional[int] = None,
page_size: Optional[int] = None) -> Dict:
"""
列出所有模型文件
参数:
product_id: 产品ID过滤 (可选)
model_type: 模型类型过滤 (可选)
store_id: 店铺ID过滤 (可选)
training_mode: 训练模式过滤 (可选)
page: 页码从1开始 (可选)
page_size: 每页数量 (可选)
返回:
包含模型列表和分页信息的字典
"""
models = []
# 搜索所有.pth文件
pattern = os.path.join(self.model_dir, "*.pth")
model_files = glob.glob(pattern)
for model_file in model_files:
try:
# 解析文件名
filename = os.path.basename(model_file)
model_info = self.parse_model_filename(filename)
if not model_info:
continue
# 尝试从模型文件中读取额外信息
try:
# Try with weights_only=False first for backward compatibility
try:
model_data = torch.load(model_file, map_location='cpu', weights_only=False)
except Exception:
# If that fails, try with weights_only=True (newer PyTorch versions)
model_data = torch.load(model_file, map_location='cpu', weights_only=True)
if 'model_manager_info' in model_data:
# 使用新的管理信息
manager_info = model_data['model_manager_info']
model_info.update(manager_info)
# 添加评估指标
if 'metrics' in model_data:
model_info['metrics'] = model_data['metrics']
# 添加配置信息
if 'config' in model_data:
model_info['config'] = model_data['config']
except Exception as e:
print(f"读取模型文件失败 {model_file}: {e}")
# Continue with just the filename-based info
# 应用过滤器
if product_id and model_info.get('product_id') != product_id:
continue
if model_type and model_info.get('model_type') != model_type:
continue
if store_id and model_info.get('store_id') != store_id:
continue
if training_mode and model_info.get('training_mode') != training_mode:
continue
# 添加文件信息
model_info['filename'] = filename
model_info['file_path'] = model_file
model_info['file_size'] = os.path.getsize(model_file)
model_info['modified_at'] = datetime.fromtimestamp(
os.path.getmtime(model_file)
).isoformat()
models.append(model_info)
except Exception as e:
print(f"处理模型文件失败 {model_file}: {e}")
continue
# 按创建时间排序(最新的在前)
models.sort(key=lambda x: x.get('created_at', x.get('modified_at', '')), reverse=True)
# 计算分页信息
total_count = len(models)
# 如果没有指定分页参数,返回所有数据
if page is None or page_size is None:
return {
'models': models,
'pagination': {
'total': total_count,
'page': 1,
'page_size': total_count,
'total_pages': 1,
'has_next': False,
'has_previous': False
}
}
# 应用分页
total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 1
start_index = (page - 1) * page_size
end_index = start_index + page_size
paginated_models = models[start_index:end_index]
return {
'models': paginated_models,
'pagination': {
'total': total_count,
'page': page,
'page_size': page_size,
'total_pages': total_pages,
'has_next': page < total_pages,
'has_previous': page > 1
}
}
def parse_model_filename(self, filename: str) -> Optional[Dict]:
"""
解析模型文件名,提取模型信息
支持的格式:
- {model_type}_product_{product_id}_{version}.pth
- {model_type}_store_{store_id}_{product_id}_{version}.pth
- {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
- 旧格式兼容
"""
if not filename.endswith('.pth'):
return None
base_name = filename.replace('.pth', '')
try:
# 新格式解析
if '_product_' in base_name:
# 产品模式: model_type_product_product_id_version
parts = base_name.split('_product_')
model_type = parts[0]
rest = parts[1]
# 分离产品ID和版本
if '_v' in rest:
last_v_index = rest.rfind('_v')
product_id = rest[:last_v_index]
version = rest[last_v_index+1:]
else:
product_id = rest
version = 'v1'
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'product',
'store_id': None,
'aggregation_method': None
}
elif '_store_' in base_name:
# 店铺模式: model_type_store_store_id_product_id_version
parts = base_name.split('_store_')
model_type = parts[0]
rest = parts[1]
# 分离店铺ID、产品ID和版本
rest_parts = rest.split('_')
if len(rest_parts) >= 3:
store_id = rest_parts[0]
if rest_parts[-1].startswith('v'):
# 最后一部分是版本号
version = rest_parts[-1]
product_id = '_'.join(rest_parts[1:-1])
else:
version = 'v1'
product_id = '_'.join(rest_parts[1:])
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'store',
'store_id': store_id,
'aggregation_method': None
}
elif '_global_' in base_name:
# 全局模式: model_type_global_product_id_aggregation_method_version
parts = base_name.split('_global_')
model_type = parts[0]
rest = parts[1]
rest_parts = rest.split('_')
if len(rest_parts) >= 3:
if rest_parts[-1].startswith('v'):
# 最后一部分是版本号
version = rest_parts[-1]
aggregation_method = rest_parts[-2]
product_id = '_'.join(rest_parts[:-2])
else:
version = 'v1'
aggregation_method = rest_parts[-1]
product_id = '_'.join(rest_parts[:-1])
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'global',
'store_id': None,
'aggregation_method': aggregation_method
}
# 兼容以 _model.pth 结尾的格式
elif base_name.endswith('_model'):
name_part = base_name.rsplit('_model', 1)[0]
parts = name_part.split('_')
# 假设格式为 {product_id}_{...}_{model_type}_{version}
if len(parts) >= 3:
version = parts[-1]
model_type = parts[-2]
product_id = '_'.join(parts[:-2]) # The rest is product_id + scope
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'product', # Assumption
'store_id': None,
'aggregation_method': None
}
# 兼容旧格式
else:
# 尝试解析其他格式
if 'model_product_' in base_name:
parts = base_name.split('_model_product_')
model_type = parts[0]
product_part = parts[1]
if '_v' in product_part:
last_v_index = product_part.rfind('_v')
product_id = product_part[:last_v_index]
version = product_part[last_v_index+1:]
else:
product_id = product_part
version = 'v1'
return {
'model_type': model_type,
'product_id': product_id,
'version': version,
'training_mode': 'product',
'store_id': None,
'aggregation_method': None
}
except Exception as e:
print(f"解析文件名失败 {filename}: {e}")
return None
def delete_model(self, model_file: str) -> bool:
"""删除模型文件"""
try:
if os.path.exists(model_file):
os.remove(model_file)
print(f"已删除模型文件: {model_file}")
return True
else:
print(f"模型文件不存在: {model_file}")
return False
except Exception as e:
print(f"删除模型文件失败: {e}")
return False
def get_model_by_id(self, model_id: str) -> Optional[Dict]:
"""根据模型ID获取模型信息"""
models = self.list_models()
for model in models:
if model.get('filename', '').replace('.pth', '') == model_id:
return model
return None
# 全局模型管理器实例
# 确保使用项目根目录的saved_models而不是相对于当前工作目录
import os
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.dirname(os.path.dirname(current_dir)) # 向上两级到项目根目录
absolute_model_dir = os.path.join(project_root, 'saved_models')
model_manager = ModelManager(absolute_model_dir)