278 lines
8.7 KiB
Python
Raw Normal View History

"""
药店销售预测系统 - 全局配置参数
"""
import torch
import matplotlib
matplotlib.use('Agg') # 设置matplotlib后端为Agg适用于无头服务器环境
import matplotlib.pyplot as plt
import os
2025-07-02 11:05:23 +08:00
import re
import glob
# 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 获取设备GPU或CPU
def get_device():
"""获取可用的计算设备GPU或CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
# 全局设备
DEVICE = get_device()
# 数据相关配置
DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx'
DEFAULT_MODEL_DIR = 'saved_models'
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数
2025-07-02 11:05:23 +08:00
LOOK_BACK = 5 # 使用过去5天数据适应小数据集
FORECAST_HORIZON = 3 # 预测未来3天销量适应小数据集
# 训练参数
DEFAULT_EPOCHS = 50 # 训练轮次
DEFAULT_BATCH_SIZE = 32 # 批大小
DEFAULT_LEARNING_RATE = 0.001 # 学习率
# 模型参数
NUM_FEATURES = 8 # 输入特征数
EMBED_DIM = 32 # 嵌入维度
DENSE_DIM = 32 # 隐藏层神经元数
NUM_HEADS = 4 # 注意力头数
DROPOUT_RATE = 0.1 # 丢弃率
NUM_BLOCKS = 3 # 编码器解码器数
HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
2025-07-02 11:05:23 +08:00
# 版本管理配置
MODEL_VERSION_PREFIX = 'v' # 版本前缀
DEFAULT_VERSION = 'v1' # 默认版本号
# WebSocket配置
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# 创建模型保存目录
2025-07-02 11:05:23 +08:00
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
def get_next_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的下一个版本号
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
下一个版本号格式如 'v2', 'v3'
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
# 如果没有任何格式的文件,返回默认版本
if not existing_files_new and not has_old_format:
return DEFAULT_VERSION
# 提取新格式文件的版本号
versions = []
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
versions.append(int(version_match.group(1)))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.append(1)
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
if versions:
next_version_num = max(versions) + 1
return f"v{next_version_num}"
else:
return DEFAULT_VERSION
def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str:
"""
生成模型文件路径
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取下一个版本
Returns:
模型文件的完整路径
"""
if version is None:
version = get_next_model_version(product_id, model_type)
# 特殊处理v1版本检查是否存在旧格式文件
if version == "v1":
# 检查旧格式文件是否存在
old_format_filename = f"{model_type}_model_product_{product_id}.pth"
old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename)
if os.path.exists(old_format_path):
print(f"找到旧格式模型文件: {old_format_path}将其作为v1版本")
return old_format_path
# 使用新格式文件名
filename = f"{model_type}_model_product_{product_id}_{version}.pth"
return os.path.join(DEFAULT_MODEL_DIR, filename)
def get_model_versions(product_id: str, model_type: str) -> list:
"""
获取指定产品和模型类型的所有版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
版本列表按版本号排序
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
versions = []
# 处理新格式文件
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
version_num = int(version_match.group(1))
versions.append(f"v{version_num}")
# 如果存在旧格式文件将其视为v1
if has_old_format:
if "v1" not in versions: # 避免重复添加
versions.append("v1")
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
# 按版本号排序
versions.sort(key=lambda v: int(v[1:]))
return versions
def get_latest_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的最新版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
最新版本号如果没有则返回None
"""
versions = get_model_versions(product_id, model_type)
return versions[-1] if versions else None
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
"""
保存模型版本信息到数据库
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号
file_path: 模型文件路径
metrics: 模型性能指标
"""
import sqlite3
import json
from datetime import datetime
try:
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 插入模型版本记录
cursor.execute('''
INSERT INTO model_versions (
product_id, model_type, version, file_path, created_at, metrics, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
product_id,
model_type,
version,
file_path,
datetime.now().isoformat(),
json.dumps(metrics) if metrics else None,
1 # 新模型默认为激活状态
))
conn.commit()
conn.close()
print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")
except Exception as e:
print(f"保存模型版本信息失败: {str(e)}")
def get_model_version_info(product_id: str, model_type: str, version: str = None):
"""
从数据库获取模型版本信息
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取最新版本
Returns:
模型版本信息字典
"""
import sqlite3
import json
try:
conn = sqlite3.connect('prediction_history.db')
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
if version:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ? AND version = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type, version))
else:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type))
row = cursor.fetchone()
conn.close()
if row:
result = dict(row)
if result['metrics']:
result['metrics'] = json.loads(result['metrics'])
return result
return None
except Exception as e:
print(f"获取模型版本信息失败: {str(e)}")
return None