54 lines
1.5 KiB
Python
54 lines
1.5 KiB
Python
![]() |
"""
|
|||
|
药店销售预测系统 - 全局配置参数
|
|||
|
"""
|
|||
|
|
|||
|
import torch
|
|||
|
import matplotlib
|
|||
|
matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境
|
|||
|
import matplotlib.pyplot as plt
|
|||
|
import os
|
|||
|
|
|||
|
# 解决画图中文显示问题
|
|||
|
plt.rcParams['font.sans-serif'] = ['SimHei']
|
|||
|
plt.rcParams['axes.unicode_minus'] = False
|
|||
|
|
|||
|
# 获取设备(GPU或CPU)
|
|||
|
def get_device():
|
|||
|
"""获取可用的计算设备(GPU或CPU)"""
|
|||
|
if torch.cuda.is_available():
|
|||
|
return torch.device('cuda')
|
|||
|
else:
|
|||
|
return torch.device('cpu')
|
|||
|
|
|||
|
# 全局设备
|
|||
|
DEVICE = get_device()
|
|||
|
|
|||
|
# 数据相关配置
|
|||
|
DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx'
|
|||
|
DEFAULT_MODEL_DIR = 'saved_models'
|
|||
|
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
|
|||
|
|
|||
|
# 时间序列参数
|
|||
|
LOOK_BACK = 14 # 使用过去14天数据
|
|||
|
FORECAST_HORIZON = 7 # 预测未来7天销量
|
|||
|
|
|||
|
# 训练参数
|
|||
|
DEFAULT_EPOCHS = 50 # 训练轮次
|
|||
|
DEFAULT_BATCH_SIZE = 32 # 批大小
|
|||
|
DEFAULT_LEARNING_RATE = 0.001 # 学习率
|
|||
|
|
|||
|
# 模型参数
|
|||
|
NUM_FEATURES = 8 # 输入特征数
|
|||
|
EMBED_DIM = 32 # 嵌入维度
|
|||
|
DENSE_DIM = 32 # 隐藏层神经元数
|
|||
|
NUM_HEADS = 4 # 注意力头数
|
|||
|
DROPOUT_RATE = 0.1 # 丢弃率
|
|||
|
NUM_BLOCKS = 3 # 编码器解码器数
|
|||
|
HIDDEN_SIZE = 64 # 隐藏层大小
|
|||
|
NUM_LAYERS = 2 # 层数
|
|||
|
|
|||
|
# 支持的模型类型
|
|||
|
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
|
|||
|
|
|||
|
# 创建模型保存目录
|
|||
|
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
|