**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。 **1. 修改 `server/trainers/kan_trainer.py`** * **内容**: 完全重写了 `kan_trainer.py`。 * **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。 * **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。 * **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。 * **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。 **2. 修改 `server/trainers/tcn_trainer.py`** * **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。 * 移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。 * 全面转向使用 `model_manager` 进行版本控制和文件保存。 * 统一了函数签名和进度反馈逻辑。 **3. 修改 `server/trainers/transformer_trainer.py`** * **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。 * 移除了所有遗留的、基于文件名的路径拼接和保存逻辑。 * 实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。 **4. 修改 `server/core/predictor.py`** * **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。 * **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。 * **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。 * **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。 * **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
11 KiB
根文件夹:save_models
新模型文件系统设计
我们已经从“一个文件包含所有信息”的模式,转向了“目录结构本身就是信息”的模式。
基本结构:
saved_models/ ├── product/ │ ├── all/ │ │ ├── MLSTM/ │ │ │ ├── v1/ │ │ │ │ ├── model.pth │ │ │ │ ├── metadata.json │ │ │ │ ├── loss_curve.png │ │ │ │ └── checkpoint_best.pth │ │ │ └── v2/ │ │ │ └── ... │ │ └── TCN/ │ │ └── v1/ │ │ └── ... │ └── {product_id}/ │ └── ... │ ├── user/ │ └── ... │ └── versions.json
txt
关键点解读:
versions.json: 这是整个系统的“注册表”。它记录了每一种模型(由mode, scope, type唯一确定)的最新版本号。所有新的训练任务都会先读取这个文件来确定下一个版本号应该是多少,从而避免了冲突。 目录路径: 模型的路径现在包含了它的核心元数据。例如,路径 saved_models/product/all/MLSTM/v1 清晰地告诉我们: 训练模式 (Mode): product (产品模式) 范围 (Scope): all (适用于所有产品) 模型类型 (Type): MLSTM 版本 (Version): v1 版本目录内容: 每个版本目录(如 v1/)下都包含了一次完整训练的所有产物,并且文件名是标准化的: model.pth: 最终保存的、用于预测的模型。 metadata.json: 包含训练参数、数据标准化scaler对象等重要元数据。 loss_curve.png: 训练过程中的损失曲线图。 checkpoint_best.pth: 训练过程中验证集上表现最好的模型检查点。
按药品训练
1.创建 product 文件夹 2.选择药品 product下创建药品id 文件夹,根据数据范围加上相应的后缀,聚合所有店铺all,指定店铺就店铺id 3.模型类型 对应的文件下创建模型名称的文件夹 4.在模型名称的文件夹下,版本文件夹version+第几次训练 5.在版本文件下存储对应的检查点文件,最终模型文件,损失曲线图
按店铺训练
1.创建 store 文件夹 2.选择店铺 store下创建店铺id 文件夹,根据药品范围加上相应的后缀,所有药品all,指定药品就药品id 3.模型类型 对应的文件下创建模型名称的文件夹 4.在模型名称的文件夹下,版本文件夹version+第几次训练 5.在版本文件下存储对应的检查点文件,最终模型文件,损失曲线图
按全局训练
1.创建 global 文件夹 2.选择训练范围时 创建文件夹根据数据范围,所有店铺所有药品为all,选择店铺就店铺id,选择药品就药品id ,自定义范围就根据下面的店铺id创建,再在店铺id文件夹下创建对应的药品id文件夹 3.聚合方式 根据聚合方式创建对应的文件 4.模型类型 对应的文件下创建模型名称的文件夹 5.在模型名称的文件夹下,版本文件夹version+第几次训练 6.在版本文件下存储对应的检查点文件,最终模型文件,损失曲线图
优化后模型保存规则分析总结
与当前系统中将模型信息编码到文件名并将文件存储在相对扁平目录中的做法相比,新规则引入了一套更具结构化和层级化的模型保存策略。这种优化旨在提高模型文件的可管理性、可追溯性和可扩展性。
核心思想
优化后的核心思想是**“目录即元数据”**。通过创建层级分明的目录结构,将模型的训练模式、范围、类型和版本等关键信息体现在目录路径中,而不是仅仅依赖于文件名。所有与单次训练相关的产物(最终模型、检查点、损失曲线图等)都将被统一存放在同一个版本文件夹下,便于管理和溯源。
统一根目录
所有模型都将保存在 saved_models
文件夹下。
优化后的目录结构规则
1. 按药品训练 (Product Training)
- 目录结构:
saved_models/product/{product_id}_{scope}/{model_type}/v{N}/
- 路径解析:
product
: 表示这是按“药品”为核心的训练模式。{product_id}_{scope}
:{product_id}
: 训练的药品ID 。{scope}
: 数据的店铺范围。all
: 使用所有店铺的聚合数据。{store_id}
: 使用指定店铺的数据。
{model_type}
: 模型的类型 (例如mlstm
,transformer
)。v{N}
: 模型的版本号 (例如v1
,v2
)。
- 文件夹内容:
- 最终模型文件 (例如
model_final.pth
) - 训练检查点文件 (例如
checkpoint_epoch_10.pth
,checkpoint_best.pth
) - 损失曲线图 (例如
loss_curve.png
)
- 最终模型文件 (例如
2. 按店铺训练 (Store Training)
- 目录结构:
saved_models/store/{store_id}_{scope}/{model_type}/v{N}/
- 路径解析:
store
: 表示这是按“店铺”为核心的训练模式。{store_id}_{scope}
:{store_id}
: 训练的店铺ID 。{scope}
: 数据的药品范围。all
: 使用该店铺所有药品的聚合数据。{product_id}
: 使用该店铺指定药品
v{N}
: 模型的版本号。
- 文件夹内容: 与“按药品训练”模式相同。
3. 全局训练 (Global Training)
- 目录结构:
saved_models/global/{scope_path}/{aggregation_method}/{model_type}/v{N}/
- 路径解析:
global
: 表示这是“全局”训练模式。{scope_path}
: 描述训练所用数据的范围,结构比较灵活:all
: 代表所有店铺的所有药品。stores/{store_id}
: 代表选择了特定的店铺。products/{product_id}
: 代表选择了特定的药品。custom/{store_id}/{product_id}
: 代表自定义范围,同时指定了店铺和药品。
{aggregation_method}
: 数据的聚合方式 (例如sum
,mean
)。{model_type}
: 模型的类型。v{N}
: 模型的版本号。
- 文件夹内容: 与“按药品训练”模式相同。
总结
总的来说,优化后的规则通过一个清晰、自解释的目录结构,系统化地组织了所有训练产物。这不仅使得查找和管理特定模型变得极为方便,也为未来的自动化模型管理和部署流程奠定了坚实的基础。
优化规则下的详细文件保存、读取及数据库记录规范
基于优化后的目录结构规则,我们进一步定义详细的文件保存、读取、数据库记录及版本管理的具体规范。
一、 详细文件保存路径规则
所有训练产物都保存在对应模型的版本文件夹内,并采用统一的命名约定。
- 最终模型文件:
model.pth
- 最佳性能检查点:
checkpoint_best.pth
- 定期检查点:
checkpoint_epoch_{epoch_number}.pth
(例如:checkpoint_epoch_50.pth
) - 损失曲线图:
loss_curve.png
- 训练元数据:
metadata.json
(包含训练参数、指标等详细信息)
示例路径:
-
按药品训练 (P001, 所有店铺, mlstm, v2):
- 目录:
saved_models/product/P001_all/mlstm/v2/
- 最终模型:
saved_models/product/P001_all/mlstm/v2/model.pth
- 损失曲线:
saved_models/product/P001_all/mlstm/v2/loss_curve.png
- 目录:
-
按店铺训练 (S001, 指定药品P002, transformer, v1):
- 目录:
saved_models/store/S001_P002/transformer/v1/
- 最终模型:
saved_models/store/S001_P002/transformer/v1/model.pth
- 目录:
-
全局训练 (所有数据, sum聚合, kan, v5):
- 目录:
saved_models/global/all/sum/kan/v5/
- 最终模型:
saved_models/global/all/sum/kan/v5/model.pth
- 目录:
二、 文件读取规则
读取模型或其产物时,首先根据模型的元数据构建其版本目录路径,然后在该目录内定位具体文件。
读取逻辑:
-
确定模型元数据:
- 训练模式 (
product
,store
,global
) - 范围 (
{product_id}_{scope}
,{store_id}_{scope}
,{scope_path}
) - 聚合方式 (仅全局模式)
- 模型类型 (
mlstm
,kan
, etc.) - 版本号 (
v{N}
)
- 训练模式 (
-
构建模型根目录路径: 根据上述元数据拼接路径。
- 示例: 要读取“店铺S001下P002药品的transformer模型v1”,构建路径
saved_models/store/S001_P002/transformer/v1/
。
- 示例: 要读取“店铺S001下P002药品的transformer模型v1”,构建路径
-
定位具体文件: 在构建好的目录下直接读取所需文件。
- 加载最终模型: 读取
model.pth
。 - 加载最佳模型: 读取
checkpoint_best.pth
。 - 查看损失曲线: 读取
loss_curve.png
。
- 加载最终模型: 读取
三、 数据库保存规则
数据库的核心职责是索引模型,而不是存储冗余信息。因此,数据库中只保存足以定位到模型版本目录的路径信息。
model_versions
表结构优化:
字段名 | 类型 | 描述 | 示例 |
---|---|---|---|
id |
INTEGER | 主键 | 1 |
model_identifier |
TEXT | 模型的唯一标识符,由模式和范围构成 | product_P001_all |
model_type |
TEXT | 模型类型 | mlstm |
version |
TEXT | 版本号 | v2 |
model_path |
TEXT | 模型版本目录的相对路径 | saved_models/product/P001_all/mlstm/v2/ |
created_at |
TEXT | 创建时间 | 2025-07-15 18:40:00 |
metrics_summary |
TEXT | 关键性能指标的JSON字符串 | {"rmse": 10.5, "r2": 0.89} |
保存逻辑:
- 当一次训练成功完成并生成版本
v{N}
后,向model_versions
表中插入一条新记录。 model_path
字段只记录到版本目录,如saved_models/product/P001_all/mlstm/v2/
。应用程序根据此路径和标准文件名(如model.pth
)来加载具体文件。
四、 版本记录文件规则
为了快速、方便地获取和递增版本号,在 saved_models
根目录下维护一个版本记录文件。
- 文件名:
versions.json
- 位置:
saved_models/versions.json
- 结构: 一个JSON对象,
key
是模型的唯一标识符,value
是该模型的最新版本号 (整数)。
versions.json
示例:
{
"product_P001_all_mlstm": 2,
"store_S001_P002_transformer": 1,
"global_all_sum_kan": 5
}
版本管理流程:
-
获取下一个版本号:
- 在开始新训练前,根据训练参数构建模型的唯一标识符 (例如
product_P001_all_mlstm
)。 - 读取
saved_models/versions.json
文件。 - 查找对应的
key
,获取当前最新版本号。如果key
不存在,则当前版本为 0。 - 下一个版本号即为
当前版本号 + 1
。
- 在开始新训练前,根据训练参数构建模型的唯一标识符 (例如
-
更新版本号:
- 训练成功后,将新的版本号写回到
saved_models/versions.json
文件中,更新对应key
的value
。 - 这个过程需要加锁以防止并发训练时出现版本号冲突。
- 训练成功后,将新的版本号写回到