""" 药店销售预测系统 - 全局配置参数 """ import torch import matplotlib matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境 import matplotlib.pyplot as plt import os import re import glob # 项目根目录 # __file__ 是当前文件 (config.py) 的路径 # os.path.dirname(__file__) 是 server/core # os.path.join(..., '..') 是 server # os.path.join(..., '..', '..') 是项目根目录 PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..')) # 解决画图中文显示问题 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # 获取设备(GPU或CPU) def get_device(): """获取可用的计算设备(GPU或CPU)""" if torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu') # 全局设备 DEVICE = get_device() # 数据相关配置 # 使用 os.path.join 构造跨平台的路径 DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet') DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models') DEFAULT_FEATURES = ['sales', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] # 时间序列参数 LOOK_BACK = 5 # 使用过去5天数据(适应小数据集) FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集) # 训练参数 DEFAULT_EPOCHS = 50 # 训练轮次 DEFAULT_BATCH_SIZE = 32 # 批大小 DEFAULT_LEARNING_RATE = 0.001 # 学习率 # 模型参数 NUM_FEATURES = 8 # 输入特征数 EMBED_DIM = 32 # 嵌入维度 DENSE_DIM = 32 # 隐藏层神经元数 NUM_HEADS = 4 # 注意力头数 DROPOUT_RATE = 0.1 # 丢弃率 NUM_BLOCKS = 3 # 编码器解码器数 HIDDEN_SIZE = 64 # 隐藏层大小 NUM_LAYERS = 2 # 层数 # 支持的模型类型 SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan'] # 版本管理配置 MODEL_VERSION_PREFIX = 'v' # 版本前缀 DEFAULT_VERSION = 'v1' # 默认版本号 # WebSocket配置 WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间 TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒) # 创建模型保存目录 os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True) # 注意:所有与模型路径、版本管理相关的函数(如 get_next_model_version, get_model_file_path 等) # 已被移除,因为这些功能现在由 server.utils.file_save.ModelPathManager 统一处理。 # 这种集中化管理确保了整个应用程序遵循统一的、基于规范的扁平化文件保存策略。