数据文件保存机构改为### 1.2. 文件存储位置 - **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。 - **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。 ### 1.3. 文件名生成规则 1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。 - *示例*: `product/P001_all/mlstm/v2` 2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。 - *示例*: `product_P001_all_mlstm_v2` 3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。 - `_model.pth` - `_loss_curve.png` - `_checkpoint_best.pth` - `_checkpoint_epoch_{N}.pth` #### **完整示例:** - **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth` - **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth` - **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth`
22 lines
691 B
Python
22 lines
691 B
Python
"""
|
|
药店销售预测系统 - 模型训练模块
|
|
"""
|
|
|
|
from .mlstm_trainer import train_product_model_with_mlstm
|
|
from .kan_trainer import train_product_model_with_kan
|
|
from .tcn_trainer import train_product_model_with_tcn
|
|
from .transformer_trainer import train_product_model_with_transformer
|
|
from .xgboost_trainer import train_product_model_with_xgboost
|
|
|
|
# 默认训练函数
|
|
from .mlstm_trainer import train_product_model_with_mlstm as train_product_model
|
|
|
|
__all__ = [
|
|
'train_product_model',
|
|
'train_product_model_with_mlstm',
|
|
'train_product_model_with_kan',
|
|
'train_product_model_with_tcn',
|
|
'train_product_model_with_transformer',
|
|
'train_product_model_with_xgboost'
|
|
]
|