**版本**: 4.0 (最终版) **核心思想**: 逻辑路径被转换为文件名的一部分,实现极致扁平化的文件存储。 --- ## 一、 文件保存规则 ### 1.1. 核心原则 所有元数据都被编码到文件名中。一个逻辑上的层级路径(例如 `product/P001_all/mlstm/v2`)应该被转换为一个用下划线连接的文件名前缀(`product_P001_all_mlstm_v2`)。 ### 1.2. 文件存储位置 - **最终产物**: 所有最终模型、元数据文件、损失图等,统一存放在 `saved_models/` 根目录下。 - **过程文件**: 所有训练过程中的检查点文件,统一存放在 `saved_models/checkpoints/` 目录下。 ### 1.3. 文件名生成规则 1. **构建逻辑路径**: 根据训练参数(模式、范围、类型、版本)确定逻辑路径。 - *示例*: `product/P001_all/mlstm/v2` 2. **生成文件名前缀**: 将逻辑路径中的所有 `/` 替换为 `_`。 - *示例*: `product_P001_all_mlstm_v2` 3. **拼接文件后缀**: 在前缀后加上描述文件类型的后缀。 - `_model.pth` - `_metadata.json` - `_loss_curve.png` - `_checkpoint_best.pth` - `_checkpoint_epoch_{N}.pth` #### **完整示例:** - **最终模型**: `saved_models/product_P001_all_mlstm_v2_model.pth` - **元数据**: `saved_models/product_P001_all_mlstm_v2_metadata.json` - **最佳检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_best.pth` - **Epoch 50 检查点**: `saved_models/checkpoints/product_P001_all_mlstm_v2_checkpoint_epoch_50.pth` --- ## 二、 文件读取规则 1. **确定模型元数据**: 根据需求确定要加载的模型的训练模式、范围、类型和版本。 2. **构建文件名前缀**: 按照与保存时相同的逻辑,将元数据拼接成文件名前缀(例如 `product_P001_all_mlstm_v2`)。 3. **定位文件**: - 要加载最终模型,查找文件: `saved_models/{prefix}_model.pth`。 - 要加载最佳检查点,查找文件: `saved_models/checkpoints/{prefix}_checkpoint_best.pth`。 --- ## 三、 数据库存储规则 数据库用于索引,应存储足以重构文件名前缀的关键元数据。 #### **`models` 表结构建议:** | 字段名 | 类型 | 描述 | 示例 | | :--- | :--- | :--- | :--- | | `id` | INTEGER | 主键 | 1 | | `filename_prefix` | TEXT | **完整文件名前缀,可作为唯一标识** | `product_P001_all_mlstm_v2` | | `model_identifier`| TEXT | 用于版本控制的标识符 (不含版本) | `product_P001_all_mlstm` | | `version` | INTEGER | 版本号 | `2` | | `status` | TEXT | 模型状态 | `completed`, `training`, `failed` | | `created_at` | TEXT | 创建时间 | `2025-07-21 02:29:00` | | `metrics_summary`| TEXT | 关键性能指标的JSON字符串 | `{"rmse": 10.5, "r2": 0.89}` | #### **保存逻辑:** - 训练完成后,向表中插入一条记录。`filename_prefix` 字段是查找与该次训练相关的所有文件的关键。 --- ## 四、 版本记录规则 版本管理依赖于根目录下的 `versions.json` 文件,以实现原子化、线程安全的版本号递增。 - **文件名**: `versions.json` - **位置**: `saved_models/versions.json` - **结构**: 一个JSON对象,`key` 是不包含版本号的标识符,`value` 是该标识符下最新的版本号(整数)。 - **Key**: `{prefix_core}_{model_type}` (例如: `product_P001_all_mlstm`) - **Value**: `Integer` #### **`versions.json` 示例:** ```json { "product_P001_all_mlstm": 2, "store_S001_P002_transformer": 1 } ``` #### **版本管理流程:** 1. **获取新版本**: 开始训练前,构建 `key`。读取 `versions.json`,找到对应 `key` 的 `value`。新版本号为 `value + 1` (若key不存在,则为 `1`)。 2. **更新版本**: 训练成功后,将新的版本号写回到 `versions.json`。此过程**必须使用文件锁**以防止并发冲突。 调试完成药品预测和店铺预测
76 lines
2.5 KiB
Python
76 lines
2.5 KiB
Python
"""
|
||
药店销售预测系统 - 全局配置参数
|
||
"""
|
||
|
||
import torch
|
||
import matplotlib
|
||
matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境
|
||
import matplotlib.pyplot as plt
|
||
import os
|
||
import re
|
||
import glob
|
||
|
||
# 项目根目录
|
||
# __file__ 是当前文件 (config.py) 的路径
|
||
# os.path.dirname(__file__) 是 server/core
|
||
# os.path.join(..., '..') 是 server
|
||
# os.path.join(..., '..', '..') 是项目根目录
|
||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||
|
||
# 解决画图中文显示问题
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
# 获取设备(GPU或CPU)
|
||
def get_device():
|
||
"""获取可用的计算设备(GPU或CPU)"""
|
||
if torch.cuda.is_available():
|
||
return torch.device('cuda')
|
||
else:
|
||
return torch.device('cpu')
|
||
|
||
# 全局设备
|
||
DEVICE = get_device()
|
||
|
||
# 数据相关配置
|
||
# 使用 os.path.join 构造跨平台的路径
|
||
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
|
||
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
|
||
DEFAULT_FEATURES = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||
|
||
# 时间序列参数
|
||
LOOK_BACK = 5 # 使用过去5天数据(适应小数据集)
|
||
FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集)
|
||
|
||
# 训练参数
|
||
DEFAULT_EPOCHS = 50 # 训练轮次
|
||
DEFAULT_BATCH_SIZE = 32 # 批大小
|
||
DEFAULT_LEARNING_RATE = 0.001 # 学习率
|
||
|
||
# 模型参数
|
||
NUM_FEATURES = 8 # 输入特征数
|
||
EMBED_DIM = 32 # 嵌入维度
|
||
DENSE_DIM = 32 # 隐藏层神经元数
|
||
NUM_HEADS = 4 # 注意力头数
|
||
DROPOUT_RATE = 0.1 # 丢弃率
|
||
NUM_BLOCKS = 3 # 编码器解码器数
|
||
HIDDEN_SIZE = 64 # 隐藏层大小
|
||
NUM_LAYERS = 2 # 层数
|
||
|
||
# 支持的模型类型
|
||
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
||
|
||
# 版本管理配置
|
||
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
||
DEFAULT_VERSION = 'v1' # 默认版本号
|
||
|
||
# WebSocket配置
|
||
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
|
||
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
|
||
|
||
# 创建模型保存目录
|
||
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
|
||
|
||
# 注意:所有与模型路径、版本管理相关的函数(如 get_next_model_version, get_model_file_path 等)
|
||
# 已被移除,因为这些功能现在由 server.utils.file_save.ModelPathManager 统一处理。
|
||
# 这种集中化管理确保了整个应用程序遵循统一的、基于规范的扁平化文件保存策略。 |