54 lines
1.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 全局配置参数
"""
import torch
import matplotlib
matplotlib.use('Agg') # 设置matplotlib后端为Agg适用于无头服务器环境
import matplotlib.pyplot as plt
import os
# 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 获取设备GPU或CPU
def get_device():
"""获取可用的计算设备GPU或CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
# 全局设备
DEVICE = get_device()
# 数据相关配置
DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx'
DEFAULT_MODEL_DIR = 'saved_models'
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数
LOOK_BACK = 14 # 使用过去14天数据
FORECAST_HORIZON = 7 # 预测未来7天销量
# 训练参数
DEFAULT_EPOCHS = 50 # 训练轮次
DEFAULT_BATCH_SIZE = 32 # 批大小
DEFAULT_LEARNING_RATE = 0.001 # 学习率
# 模型参数
NUM_FEATURES = 8 # 输入特征数
EMBED_DIM = 32 # 嵌入维度
DENSE_DIM = 32 # 隐藏层神经元数
NUM_HEADS = 4 # 注意力头数
DROPOUT_RATE = 0.1 # 丢弃率
NUM_BLOCKS = 3 # 编码器解码器数
HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# 创建模型保存目录
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)