""" 药店销售预测系统 - 全局配置参数 """ import torch import matplotlib matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境 import matplotlib.pyplot as plt import os # 解决画图中文显示问题 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # 获取设备(GPU或CPU) def get_device(): """获取可用的计算设备(GPU或CPU)""" if torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu') # 全局设备 DEVICE = get_device() # 数据相关配置 DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx' DEFAULT_MODEL_DIR = 'saved_models' DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] # 时间序列参数 LOOK_BACK = 14 # 使用过去14天数据 FORECAST_HORIZON = 7 # 预测未来7天销量 # 训练参数 DEFAULT_EPOCHS = 50 # 训练轮次 DEFAULT_BATCH_SIZE = 32 # 批大小 DEFAULT_LEARNING_RATE = 0.001 # 学习率 # 模型参数 NUM_FEATURES = 8 # 输入特征数 EMBED_DIM = 32 # 嵌入维度 DENSE_DIM = 32 # 隐藏层神经元数 NUM_HEADS = 4 # 注意力头数 DROPOUT_RATE = 0.1 # 丢弃率 NUM_BLOCKS = 3 # 编码器解码器数 HIDDEN_SIZE = 64 # 隐藏层大小 NUM_LAYERS = 2 # 层数 # 支持的模型类型 SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan'] # 创建模型保存目录 os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)