""" 药店销售预测系统 - 全局配置参数 """ import torch import matplotlib matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境 import matplotlib.pyplot as plt import os import re import glob # 解决画图中文显示问题 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # 获取设备(GPU或CPU) def get_device(): """获取可用的计算设备(GPU或CPU)""" if torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu') # 全局设备 DEVICE = get_device() # 数据相关配置 DEFAULT_DATA_PATH = 'pharmacy_sales.xlsx' DEFAULT_MODEL_DIR = 'saved_models' DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] # 时间序列参数 LOOK_BACK = 5 # 使用过去5天数据(适应小数据集) FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集) # 训练参数 DEFAULT_EPOCHS = 50 # 训练轮次 DEFAULT_BATCH_SIZE = 32 # 批大小 DEFAULT_LEARNING_RATE = 0.001 # 学习率 # 模型参数 NUM_FEATURES = 8 # 输入特征数 EMBED_DIM = 32 # 嵌入维度 DENSE_DIM = 32 # 隐藏层神经元数 NUM_HEADS = 4 # 注意力头数 DROPOUT_RATE = 0.1 # 丢弃率 NUM_BLOCKS = 3 # 编码器解码器数 HIDDEN_SIZE = 64 # 隐藏层大小 NUM_LAYERS = 2 # 层数 # 支持的模型类型 SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan'] # 版本管理配置 MODEL_VERSION_PREFIX = 'v' # 版本前缀 DEFAULT_VERSION = 'v1' # 默认版本号 # WebSocket配置 WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间 TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒) # 创建模型保存目录 os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True) def get_next_model_version(product_id: str, model_type: str) -> str: """ 获取指定产品和模型类型的下一个版本号 Args: product_id: 产品ID model_type: 模型类型 Returns: 下一个版本号,格式如 'v2', 'v3' 等 """ # 新格式:带版本号的文件 pattern_new = f"{model_type}_model_product_{product_id}_v*.pth" existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new)) # 旧格式:不带版本号的文件(兼容性支持) pattern_old = f"{model_type}_model_product_{product_id}.pth" old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old) has_old_format = os.path.exists(old_file_path) # 如果没有任何格式的文件,返回默认版本 if not existing_files_new and not has_old_format: return DEFAULT_VERSION # 提取新格式文件的版本号 versions = [] for file_path in existing_files_new: filename = os.path.basename(file_path) version_match = re.search(rf"_v(\d+)\.pth$", filename) if version_match: versions.append(int(version_match.group(1))) # 如果存在旧格式文件,将其视为v1 if has_old_format: versions.append(1) print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1") if versions: next_version_num = max(versions) + 1 return f"v{next_version_num}" else: return DEFAULT_VERSION def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str: """ 生成模型文件路径 Args: product_id: 产品ID model_type: 模型类型 version: 版本号,如果为None则获取下一个版本 Returns: 模型文件的完整路径 """ if version is None: version = get_next_model_version(product_id, model_type) # 特殊处理v1版本:检查是否存在旧格式文件 if version == "v1": # 检查旧格式文件是否存在 old_format_filename = f"{model_type}_model_product_{product_id}.pth" old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename) if os.path.exists(old_format_path): print(f"找到旧格式模型文件: {old_format_path},将其作为v1版本") return old_format_path # 使用新格式文件名 filename = f"{model_type}_model_product_{product_id}_{version}.pth" return os.path.join(DEFAULT_MODEL_DIR, filename) def get_model_versions(product_id: str, model_type: str) -> list: """ 获取指定产品和模型类型的所有版本 Args: product_id: 产品ID model_type: 模型类型 Returns: 版本列表,按版本号排序 """ # 新格式:带版本号的文件 pattern_new = f"{model_type}_model_product_{product_id}_v*.pth" existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new)) # 旧格式:不带版本号的文件(兼容性支持) pattern_old = f"{model_type}_model_product_{product_id}.pth" old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old) has_old_format = os.path.exists(old_file_path) versions = [] # 处理新格式文件 for file_path in existing_files_new: filename = os.path.basename(file_path) version_match = re.search(rf"_v(\d+)\.pth$", filename) if version_match: version_num = int(version_match.group(1)) versions.append(f"v{version_num}") # 如果存在旧格式文件,将其视为v1 if has_old_format: if "v1" not in versions: # 避免重复添加 versions.append("v1") print(f"检测到旧格式模型文件: {old_file_path},将其视为版本v1") # 按版本号排序 versions.sort(key=lambda v: int(v[1:])) return versions def get_latest_model_version(product_id: str, model_type: str) -> str: """ 获取指定产品和模型类型的最新版本 Args: product_id: 产品ID model_type: 模型类型 Returns: 最新版本号,如果没有则返回None """ versions = get_model_versions(product_id, model_type) return versions[-1] if versions else None def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None): """ 保存模型版本信息到数据库 Args: product_id: 产品ID model_type: 模型类型 version: 版本号 file_path: 模型文件路径 metrics: 模型性能指标 """ import sqlite3 import json from datetime import datetime try: conn = sqlite3.connect('prediction_history.db') cursor = conn.cursor() # 插入模型版本记录 cursor.execute(''' INSERT INTO model_versions ( product_id, model_type, version, file_path, created_at, metrics, is_active ) VALUES (?, ?, ?, ?, ?, ?, ?) ''', ( product_id, model_type, version, file_path, datetime.now().isoformat(), json.dumps(metrics) if metrics else None, 1 # 新模型默认为激活状态 )) conn.commit() conn.close() print(f"已保存模型版本信息: {product_id}_{model_type}_{version}") except Exception as e: print(f"保存模型版本信息失败: {str(e)}") def get_model_version_info(product_id: str, model_type: str, version: str = None): """ 从数据库获取模型版本信息 Args: product_id: 产品ID model_type: 模型类型 version: 版本号,如果为None则获取最新版本 Returns: 模型版本信息字典 """ import sqlite3 import json try: conn = sqlite3.connect('prediction_history.db') conn.row_factory = sqlite3.Row cursor = conn.cursor() if version: cursor.execute(''' SELECT * FROM model_versions WHERE product_id = ? AND model_type = ? AND version = ? ORDER BY created_at DESC LIMIT 1 ''', (product_id, model_type, version)) else: cursor.execute(''' SELECT * FROM model_versions WHERE product_id = ? AND model_type = ? ORDER BY created_at DESC LIMIT 1 ''', (product_id, model_type)) row = cursor.fetchone() conn.close() if row: result = dict(row) if result['metrics']: result['metrics'] = json.loads(result['metrics']) return result return None except Exception as e: print(f"获取模型版本信息失败: {str(e)}") return None