**版本**: 4.0 (最终版) **核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。 --- ## 一、 文件保存规则 ### 1.1. 核心原则 所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。 ### 1.2. 文件存储位置 - **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。 - **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。 ### 1.3. 文件名生成规则 1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。 - *示例*: `product/P001_all/mlstm/v2` 2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。 - *示例*: `product_P001_all_mlstm_v2` 3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。 - `_model.pth` - `_metadata.json` - `_loss_curve.png` - `_checkpoint_best.pth` - `_checkpoint_epoch_{N}.pth` #### **完整示例:** - **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth` - **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json` - **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth` - **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth` --- ## 二、 文件读取规则 1. **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。 2. **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。 3. **定位文件**: - 要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。 - 要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。 --- ## 三、 数据库存储规则 数据库用于索引,应存储足以重构文件名前缀的关键元数据。 #### **`models` 表结构建议:** | 字段名 | 类型 | 描述 | 示例 | | :--- | :--- | :--- | :--- | | `id` | INTEGER | 主键 | 1 | | `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` | | `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` | | `version` | INTEGER | 版本号 | `2` | | `status` | TEXT | 模型状态 | `completed`, `training`, `failed` | | `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` | | `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` | #### **保存逻辑:** - 训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。 --- ## 四、 版本记录规则 版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。 - **文件名**: `versions.json` - **位置**: `saved_models/versions.json` - **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。 - **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`) - **Value**: `Integer` #### **`versions.json` 示例:** ```json { "product_P001_all_mlstm": 2, "store_S001_P002_transformer": 1 } ``` #### **版本管理流程:** 1. **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。 2. **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。 调试完成药品预测和店铺预测
401 lines
15 KiB
Python
401 lines
15 KiB
Python
"""
|
||
统一模型管理工具
|
||
处理模型文件的统一命名、存储和检索
|
||
"""
|
||
|
||
import os
|
||
import json
|
||
import torch
|
||
import glob
|
||
from datetime import datetime
|
||
from typing import List, Dict, Optional, Tuple
|
||
from core.config import DEFAULT_MODEL_DIR
|
||
|
||
|
||
class ModelManager:
|
||
"""统一模型管理器"""
|
||
|
||
def __init__(self, model_dir: str = DEFAULT_MODEL_DIR):
|
||
self.model_dir = model_dir
|
||
self.ensure_model_dir()
|
||
|
||
def ensure_model_dir(self):
|
||
"""确保模型目录存在"""
|
||
if not os.path.exists(self.model_dir):
|
||
os.makedirs(self.model_dir)
|
||
|
||
def generate_model_filename(self,
|
||
product_id: str,
|
||
model_type: str,
|
||
version: str,
|
||
store_id: Optional[str] = None,
|
||
training_mode: str = 'product',
|
||
aggregation_method: Optional[str] = None) -> str:
|
||
"""
|
||
生成统一的模型文件名
|
||
|
||
格式规范:
|
||
- 产品模式: {model_type}_product_{product_id}_{version}.pth
|
||
- 店铺模式: {model_type}_store_{store_id}_{product_id}_{version}.pth
|
||
- 全局模式: {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
|
||
"""
|
||
if training_mode == 'store' and store_id:
|
||
return f"{model_type}_store_{store_id}_{product_id}_{version}.pth"
|
||
elif training_mode == 'global' and aggregation_method:
|
||
return f"{model_type}_global_{product_id}_{aggregation_method}_{version}.pth"
|
||
else:
|
||
# 默认产品模式
|
||
return f"{model_type}_product_{product_id}_{version}.pth"
|
||
|
||
def save_model(self,
|
||
model_data: dict,
|
||
product_id: str,
|
||
model_type: str,
|
||
version: str,
|
||
store_id: Optional[str] = None,
|
||
training_mode: str = 'product',
|
||
aggregation_method: Optional[str] = None,
|
||
product_name: Optional[str] = None) -> str:
|
||
"""
|
||
保存模型到统一位置
|
||
|
||
参数:
|
||
model_data: 包含模型状态和配置的字典
|
||
product_id: 产品ID
|
||
model_type: 模型类型
|
||
version: 版本号
|
||
store_id: 店铺ID (可选)
|
||
training_mode: 训练模式
|
||
aggregation_method: 聚合方法 (可选)
|
||
product_name: 产品名称 (可选)
|
||
|
||
返回:
|
||
模型文件路径
|
||
"""
|
||
filename = self.generate_model_filename(
|
||
product_id, model_type, version, store_id, training_mode, aggregation_method
|
||
)
|
||
|
||
# 统一保存到根目录,避免复杂的子目录结构
|
||
model_path = os.path.join(self.model_dir, filename)
|
||
|
||
# 增强模型数据,添加管理信息
|
||
enhanced_model_data = model_data.copy()
|
||
enhanced_model_data.update({
|
||
'model_manager_info': {
|
||
'product_id': product_id,
|
||
'product_name': product_name or product_id,
|
||
'model_type': model_type,
|
||
'version': version,
|
||
'store_id': store_id,
|
||
'training_mode': training_mode,
|
||
'aggregation_method': aggregation_method,
|
||
'created_at': datetime.now().isoformat(),
|
||
'filename': filename
|
||
}
|
||
})
|
||
|
||
# 保存模型
|
||
torch.save(enhanced_model_data, model_path)
|
||
|
||
print(f"模型已保存: {model_path}")
|
||
return model_path
|
||
|
||
def list_models(self,
|
||
product_id: Optional[str] = None,
|
||
model_type: Optional[str] = None,
|
||
store_id: Optional[str] = None,
|
||
training_mode: Optional[str] = None,
|
||
page: Optional[int] = None,
|
||
page_size: Optional[int] = None) -> Dict:
|
||
"""
|
||
列出所有模型文件
|
||
|
||
参数:
|
||
product_id: 产品ID过滤 (可选)
|
||
model_type: 模型类型过滤 (可选)
|
||
store_id: 店铺ID过滤 (可选)
|
||
training_mode: 训练模式过滤 (可选)
|
||
page: 页码,从1开始 (可选)
|
||
page_size: 每页数量 (可选)
|
||
|
||
返回:
|
||
包含模型列表和分页信息的字典
|
||
"""
|
||
models = []
|
||
|
||
# 搜索所有.pth文件
|
||
pattern = os.path.join(self.model_dir, "*.pth")
|
||
model_files = glob.glob(pattern)
|
||
|
||
for model_file in model_files:
|
||
try:
|
||
# 解析文件名
|
||
filename = os.path.basename(model_file)
|
||
model_info = self.parse_model_filename(filename)
|
||
|
||
if not model_info:
|
||
continue
|
||
|
||
# 尝试从模型文件中读取额外信息
|
||
try:
|
||
# Try with weights_only=False first for backward compatibility
|
||
try:
|
||
model_data = torch.load(model_file, map_location='cpu', weights_only=False)
|
||
except Exception:
|
||
# If that fails, try with weights_only=True (newer PyTorch versions)
|
||
model_data = torch.load(model_file, map_location='cpu', weights_only=True)
|
||
|
||
if 'model_manager_info' in model_data:
|
||
# 使用新的管理信息
|
||
manager_info = model_data['model_manager_info']
|
||
model_info.update(manager_info)
|
||
|
||
# 添加评估指标
|
||
if 'metrics' in model_data:
|
||
model_info['metrics'] = model_data['metrics']
|
||
|
||
# 添加配置信息
|
||
if 'config' in model_data:
|
||
model_info['config'] = model_data['config']
|
||
|
||
except Exception as e:
|
||
print(f"读取模型文件失败 {model_file}: {e}")
|
||
# Continue with just the filename-based info
|
||
|
||
# 应用过滤器
|
||
if product_id and model_info.get('product_id') != product_id:
|
||
continue
|
||
if model_type and model_info.get('model_type') != model_type:
|
||
continue
|
||
if store_id and model_info.get('store_id') != store_id:
|
||
continue
|
||
if training_mode and model_info.get('training_mode') != training_mode:
|
||
continue
|
||
|
||
# 添加文件信息
|
||
model_info['filename'] = filename
|
||
model_info['file_path'] = model_file
|
||
model_info['file_size'] = os.path.getsize(model_file)
|
||
model_info['modified_at'] = datetime.fromtimestamp(
|
||
os.path.getmtime(model_file)
|
||
).isoformat()
|
||
|
||
models.append(model_info)
|
||
|
||
except Exception as e:
|
||
print(f"处理模型文件失败 {model_file}: {e}")
|
||
continue
|
||
|
||
# 按创建时间排序(最新的在前)
|
||
models.sort(key=lambda x: x.get('created_at', x.get('modified_at', '')), reverse=True)
|
||
|
||
# 计算分页信息
|
||
total_count = len(models)
|
||
|
||
# 如果没有指定分页参数,返回所有数据
|
||
if page is None or page_size is None:
|
||
return {
|
||
'models': models,
|
||
'pagination': {
|
||
'total': total_count,
|
||
'page': 1,
|
||
'page_size': total_count,
|
||
'total_pages': 1,
|
||
'has_next': False,
|
||
'has_previous': False
|
||
}
|
||
}
|
||
|
||
# 应用分页
|
||
total_pages = (total_count + page_size - 1) // page_size if page_size > 0 else 1
|
||
start_index = (page - 1) * page_size
|
||
end_index = start_index + page_size
|
||
|
||
paginated_models = models[start_index:end_index]
|
||
|
||
return {
|
||
'models': paginated_models,
|
||
'pagination': {
|
||
'total': total_count,
|
||
'page': page,
|
||
'page_size': page_size,
|
||
'total_pages': total_pages,
|
||
'has_next': page < total_pages,
|
||
'has_previous': page > 1
|
||
}
|
||
}
|
||
|
||
def parse_model_filename(self, filename: str) -> Optional[Dict]:
|
||
"""
|
||
解析模型文件名,提取模型信息
|
||
|
||
支持的格式:
|
||
- {model_type}_product_{product_id}_{version}.pth
|
||
- {model_type}_store_{store_id}_{product_id}_{version}.pth
|
||
- {model_type}_global_{product_id}_{aggregation_method}_{version}.pth
|
||
- 旧格式兼容
|
||
"""
|
||
if not filename.endswith('.pth'):
|
||
return None
|
||
|
||
base_name = filename.replace('.pth', '')
|
||
|
||
try:
|
||
# 新格式解析
|
||
if '_product_' in base_name:
|
||
# 产品模式: model_type_product_product_id_version
|
||
parts = base_name.split('_product_')
|
||
model_type = parts[0]
|
||
rest = parts[1]
|
||
|
||
# 分离产品ID和版本
|
||
if '_v' in rest:
|
||
last_v_index = rest.rfind('_v')
|
||
product_id = rest[:last_v_index]
|
||
version = rest[last_v_index+1:]
|
||
else:
|
||
product_id = rest
|
||
version = 'v1'
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'product',
|
||
'store_id': None,
|
||
'aggregation_method': None
|
||
}
|
||
|
||
elif '_store_' in base_name:
|
||
# 店铺模式: model_type_store_store_id_product_id_version
|
||
parts = base_name.split('_store_')
|
||
model_type = parts[0]
|
||
rest = parts[1]
|
||
|
||
# 分离店铺ID、产品ID和版本
|
||
rest_parts = rest.split('_')
|
||
if len(rest_parts) >= 3:
|
||
store_id = rest_parts[0]
|
||
if rest_parts[-1].startswith('v'):
|
||
# 最后一部分是版本号
|
||
version = rest_parts[-1]
|
||
product_id = '_'.join(rest_parts[1:-1])
|
||
else:
|
||
version = 'v1'
|
||
product_id = '_'.join(rest_parts[1:])
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'store',
|
||
'store_id': store_id,
|
||
'aggregation_method': None
|
||
}
|
||
|
||
elif '_global_' in base_name:
|
||
# 全局模式: model_type_global_product_id_aggregation_method_version
|
||
parts = base_name.split('_global_')
|
||
model_type = parts[0]
|
||
rest = parts[1]
|
||
|
||
rest_parts = rest.split('_')
|
||
if len(rest_parts) >= 3:
|
||
if rest_parts[-1].startswith('v'):
|
||
# 最后一部分是版本号
|
||
version = rest_parts[-1]
|
||
aggregation_method = rest_parts[-2]
|
||
product_id = '_'.join(rest_parts[:-2])
|
||
else:
|
||
version = 'v1'
|
||
aggregation_method = rest_parts[-1]
|
||
product_id = '_'.join(rest_parts[:-1])
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'global',
|
||
'store_id': None,
|
||
'aggregation_method': aggregation_method
|
||
}
|
||
|
||
# 兼容以 _model.pth 结尾的格式
|
||
elif base_name.endswith('_model'):
|
||
name_part = base_name.rsplit('_model', 1)[0]
|
||
parts = name_part.split('_')
|
||
# 假设格式为 {product_id}_{...}_{model_type}_{version}
|
||
if len(parts) >= 3:
|
||
version = parts[-1]
|
||
model_type = parts[-2]
|
||
product_id = '_'.join(parts[:-2]) # The rest is product_id + scope
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'product', # Assumption
|
||
'store_id': None,
|
||
'aggregation_method': None
|
||
}
|
||
|
||
# 兼容旧格式
|
||
else:
|
||
# 尝试解析其他格式
|
||
if 'model_product_' in base_name:
|
||
parts = base_name.split('_model_product_')
|
||
model_type = parts[0]
|
||
product_part = parts[1]
|
||
|
||
if '_v' in product_part:
|
||
last_v_index = product_part.rfind('_v')
|
||
product_id = product_part[:last_v_index]
|
||
version = product_part[last_v_index+1:]
|
||
else:
|
||
product_id = product_part
|
||
version = 'v1'
|
||
|
||
return {
|
||
'model_type': model_type,
|
||
'product_id': product_id,
|
||
'version': version,
|
||
'training_mode': 'product',
|
||
'store_id': None,
|
||
'aggregation_method': None
|
||
}
|
||
|
||
except Exception as e:
|
||
print(f"解析文件名失败 {filename}: {e}")
|
||
|
||
return None
|
||
|
||
def delete_model(self, model_file: str) -> bool:
|
||
"""删除模型文件"""
|
||
try:
|
||
if os.path.exists(model_file):
|
||
os.remove(model_file)
|
||
print(f"已删除模型文件: {model_file}")
|
||
return True
|
||
else:
|
||
print(f"模型文件不存在: {model_file}")
|
||
return False
|
||
except Exception as e:
|
||
print(f"删除模型文件失败: {e}")
|
||
return False
|
||
|
||
def get_model_by_id(self, model_id: str) -> Optional[Dict]:
|
||
"""根据模型ID获取模型信息"""
|
||
models = self.list_models()
|
||
for model in models:
|
||
if model.get('filename', '').replace('.pth', '') == model_id:
|
||
return model
|
||
return None
|
||
|
||
|
||
# 全局模型管理器实例
|
||
# 确保使用项目根目录的saved_models,而不是相对于当前工作目录
|
||
import os
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
project_root = os.path.dirname(os.path.dirname(current_dir)) # 向上两级到项目根目录
|
||
absolute_model_dir = os.path.join(project_root, 'saved_models')
|
||
model_manager = ModelManager(absolute_model_dir) |