2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
|
|
|
|
药店销售预测系统 - 全局配置参数
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import matplotlib
|
|
|
|
|
matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import os
|
2025-07-02 11:05:23 +08:00
|
|
|
|
import re
|
|
|
|
|
import glob
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
2025-07-15 10:37:25 +08:00
|
|
|
|
# 项目根目录
|
|
|
|
|
# __file__ 是当前文件 (config.py) 的路径
|
|
|
|
|
# os.path.dirname(__file__) 是 server/core
|
|
|
|
|
# os.path.join(..., '..') 是 server
|
|
|
|
|
# os.path.join(..., '..', '..') 是项目根目录
|
|
|
|
|
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
|
|
|
|
|
|
2025-06-18 06:39:41 +08:00
|
|
|
|
# 解决画图中文显示问题
|
|
|
|
|
plt.rcParams['font.sans-serif'] = ['SimHei']
|
|
|
|
|
plt.rcParams['axes.unicode_minus'] = False
|
|
|
|
|
|
|
|
|
|
# 获取设备(GPU或CPU)
|
|
|
|
|
def get_device():
|
|
|
|
|
"""获取可用的计算设备(GPU或CPU)"""
|
|
|
|
|
if torch.cuda.is_available():
|
|
|
|
|
return torch.device('cuda')
|
|
|
|
|
else:
|
|
|
|
|
return torch.device('cpu')
|
|
|
|
|
|
|
|
|
|
# 全局设备
|
|
|
|
|
DEVICE = get_device()
|
|
|
|
|
|
|
|
|
|
# 数据相关配置
|
2025-07-15 10:37:25 +08:00
|
|
|
|
# 使用 os.path.join 构造跨平台的路径
|
|
|
|
|
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
|
|
|
|
|
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
|
2025-07-21 16:38:36 +08:00
|
|
|
|
DEFAULT_FEATURES = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
|
|
|
|
# 时间序列参数
|
2025-07-02 11:05:23 +08:00
|
|
|
|
LOOK_BACK = 5 # 使用过去5天数据(适应小数据集)
|
|
|
|
|
FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集)
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
|
|
|
|
# 训练参数
|
|
|
|
|
DEFAULT_EPOCHS = 50 # 训练轮次
|
|
|
|
|
DEFAULT_BATCH_SIZE = 32 # 批大小
|
|
|
|
|
DEFAULT_LEARNING_RATE = 0.001 # 学习率
|
|
|
|
|
|
|
|
|
|
# 模型参数
|
|
|
|
|
NUM_FEATURES = 8 # 输入特征数
|
|
|
|
|
EMBED_DIM = 32 # 嵌入维度
|
|
|
|
|
DENSE_DIM = 32 # 隐藏层神经元数
|
|
|
|
|
NUM_HEADS = 4 # 注意力头数
|
|
|
|
|
DROPOUT_RATE = 0.1 # 丢弃率
|
|
|
|
|
NUM_BLOCKS = 3 # 编码器解码器数
|
|
|
|
|
HIDDEN_SIZE = 64 # 隐藏层大小
|
|
|
|
|
NUM_LAYERS = 2 # 层数
|
|
|
|
|
|
|
|
|
|
# 支持的模型类型
|
|
|
|
|
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
|
|
|
|
|
2025-07-02 11:05:23 +08:00
|
|
|
|
# 版本管理配置
|
|
|
|
|
MODEL_VERSION_PREFIX = 'v' # 版本前缀
|
|
|
|
|
DEFAULT_VERSION = 'v1' # 默认版本号
|
|
|
|
|
|
|
|
|
|
# WebSocket配置
|
|
|
|
|
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
|
|
|
|
|
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
|
|
|
|
|
|
2025-06-18 06:39:41 +08:00
|
|
|
|
# 创建模型保存目录
|
2025-07-02 11:05:23 +08:00
|
|
|
|
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
2025-07-21 16:38:36 +08:00
|
|
|
|
# 注意:所有与模型路径、版本管理相关的函数(如 get_next_model_version, get_model_file_path 等)
|
|
|
|
|
# 已被移除,因为这些功能现在由 server.utils.file_save.ModelPathManager 统一处理。
|
|
|
|
|
# 这种集中化管理确保了整个应用程序遵循统一的、基于规范的扁平化文件保存策略。
|