**核心目标**: 将新的 `ModelManager` 统一应用到项目中所有剩余的模型训练器,并重构核心调用逻辑,确保整个训练链路的架构一致性。 **1. 修改 `server/trainers/kan_trainer.py`** * **内容**: 完全重写了 `kan_trainer.py`。 * **适配接口**: 函数签名与 `mlstm_trainer` 对齐,增加了 `socketio`, `task_id`, `patience` 等参数。 * **集成 `ModelManager`**: 移除了所有旧的、手动的保存逻辑,改为在训练开始时调用 `model_manager` 获取版本号和路径。 * **标准化产物保存**: 所有产物(模型、元数据、检查点、损失曲线)均通过 `model_manager.save_model_artifact()` 保存。 * **增加健壮性**: 引入了早停(Early Stopping)和保存最佳检查点(Best Checkpoint)的逻辑。 **2. 修改 `server/trainers/tcn_trainer.py`** * **内容**: 完全重写了 `tcn_trainer.py`,应用了与 `kan_trainer` 完全相同的重构模式。 * 移除了旧的 `save_checkpoint` 辅助函数和基于 `core.config` 的版本管理。 * 全面转向使用 `model_manager` 进行版本控制和文件保存。 * 统一了函数签名和进度反馈逻辑。 **3. 修改 `server/trainers/transformer_trainer.py`** * **内容**: 完全重写了 `transformer_trainer.py`,完成了对所有训练器的统一重构。 * 移除了所有遗留的、基于文件名的路径拼接和保存逻辑。 * 实现了与其它训练器一致的、基于 `ModelManager` 的标准化训练流程。 **4. 修改 `server/core/predictor.py`** * **内容**: 对核心预测器类 `PharmacyPredictor` 进行了彻底重构。 * **统一调用接口**: `train_model` 方法现在以完全一致的方式调用所有(`mlstm`, `kan`, `tcn`, `transformer`)训练器。 * **移除旧逻辑**: 删除了 `_parse_model_filename` 等所有基于文件名解析的旧方法。 * **适配 `ModelManager`**: `list_models` 和 `delete_model` 等方法现在直接调用 `model_manager` 的相应功能,不再自己实现逻辑。 * **简化 `predict`**: 预测方法现在直接接收标准化的模型版本路径 (`model_version_path`) 作为输入,逻辑更清晰。
73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
"""
|
||
药店销售预测系统 - 全局配置参数
|
||
"""
|
||
|
||
import torch
|
||
import matplotlib
|
||
matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境
|
||
import matplotlib.pyplot as plt
|
||
import os
|
||
import re
|
||
import glob
|
||
|
||
# 项目根目录
|
||
# __file__ 是当前文件 (config.py) 的路径
|
||
# os.path.dirname(__file__) 是 server/core
|
||
# os.path.join(..., '..') 是 server
|
||
# os.path.join(..., '..', '..') 是项目根目录
|
||
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
||
|
||
# 解决画图中文显示问题
|
||
plt.rcParams['font.sans-serif'] = ['SimHei']
|
||
plt.rcParams['axes.unicode_minus'] = False
|
||
|
||
# 获取设备(GPU或CPU)
|
||
def get_device():
|
||
"""获取可用的计算设备(GPU或CPU)"""
|
||
if torch.cuda.is_available():
|
||
return torch.device('cuda')
|
||
else:
|
||
return torch.device('cpu')
|
||
|
||
# 全局设备
|
||
DEVICE = get_device()
|
||
|
||
# 数据相关配置
|
||
# 使用 os.path.join 构造跨平台的路径
|
||
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
|
||
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
|
||
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
||
|
||
# 时间序列参数
|
||
LOOK_BACK = 5 # 使用过去5天数据(适应小数据集)
|
||
FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集)
|
||
|
||
# 训练参数
|
||
DEFAULT_EPOCHS = 50 # 训练轮次
|
||
DEFAULT_BATCH_SIZE = 32 # 批大小
|
||
DEFAULT_LEARNING_RATE = 0.001 # 学习率
|
||
|
||
# 模型参数
|
||
NUM_FEATURES = 8 # 输入特征数
|
||
EMBED_DIM = 32 # 嵌入维度
|
||
DENSE_DIM = 32 # 隐藏层神经元数
|
||
NUM_HEADS = 4 # 注意力头数
|
||
DROPOUT_RATE = 0.1 # 丢弃率
|
||
NUM_BLOCKS = 3 # 编码器解码器数
|
||
HIDDEN_SIZE = 64 # 隐藏层大小
|
||
NUM_LAYERS = 2 # 层数
|
||
|
||
# 支持的模型类型
|
||
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
||
|
||
# 版本管理配置
|
||
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
||
DEFAULT_VERSION = 'v1' # 默认版本号
|
||
|
||
# WebSocket配置
|
||
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
|
||
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
|
||
|
||
# 创建模型保存目录
|
||
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
|