""" 药店销售预测系统 - 全局配置参数 """ import torch import matplotlib matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境 import matplotlib.pyplot as plt import os import re import glob # 解决画图中文显示问题 plt.rcParams['font.sans-serif'] = ['SimHei'] plt.rcParams['axes.unicode_minus'] = False # 获取设备(GPU或CPU) def get_device(): """获取可用的计算设备(GPU或CPU)""" if torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu') # 全局设备 DEVICE = get_device() # 数据相关配置 # 使用 os.path.join 构造跨平台的路径 # 修正: 改为相对路径 DEFAULT_DATA_PATH = os.path.join('data', 'timeseries_training_data_sample_10s50p.parquet') DEFAULT_MODEL_DIR = 'saved_models' DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'] # 时间序列参数 LOOK_BACK = 5 # 使用过去5天数据(适应小数据集) FORECAST_HORIZON = 3 # 预测未来3天销量(适应小数据集) # 训练参数 DEFAULT_EPOCHS = 50 # 训练轮次 DEFAULT_BATCH_SIZE = 32 # 批大小 DEFAULT_LEARNING_RATE = 0.001 # 学习率 # 模型参数 NUM_FEATURES = 8 # 输入特征数 EMBED_DIM = 32 # 嵌入维度 DENSE_DIM = 32 # 隐藏层神经元数 NUM_HEADS = 4 # 注意力头数 DROPOUT_RATE = 0.1 # 丢弃率 NUM_BLOCKS = 3 # 编码器解码器数 HIDDEN_SIZE = 64 # 隐藏层大小 NUM_LAYERS = 2 # 层数 # 支持的模型类型 # 支持的模型类型 (v2 - 动态加载) from models.model_registry import TRAINER_REGISTRY SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys()) # 版本管理配置 MODEL_VERSION_PREFIX = 'v' # 版本前缀 DEFAULT_VERSION = 'v1' # 默认版本号 # WebSocket配置 WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间 TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒) # 创建模型保存目录 os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True) def get_model_file_path(product_id: str, model_type: str, version: str) -> str: """ 根据产品ID、模型类型和版本号,生成模型文件的准确路径。 Args: product_id: 产品ID (纯数字) model_type: 模型类型 version: 版本字符串 (例如 'best', 'final_epoch_50', 'v1_legacy') Returns: 模型文件的完整路径 """ # 处理历史遗留的 "v1" 格式 if version == "v1_legacy": filename = f"{model_type}_model_product_{product_id}.pth" return os.path.join(DEFAULT_MODEL_DIR, filename) # 修正:直接使用唯一的product_id(它可能包含store_前缀)来构建文件名 # 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth # 针对 KAN 和 optimized_kan,使用 model_manager 的命名约定 # 统一所有模型的命名格式 filename = f"{model_type}_product_{product_id}_{version}.pth" # 修正:直接在根模型目录查找,不再使用checkpoints子目录 return os.path.join(DEFAULT_MODEL_DIR, filename) def get_model_versions(product_id: str, model_type: str) -> list: """ 获取指定产品和模型类型的所有版本 Args: product_id: 产品ID (现在应该是纯数字ID) model_type: 模型类型 Returns: 版本列表,按版本号排序 """ # 统一使用新的命名约定进行搜索 pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth") existing_files = glob.glob(pattern) versions = set() for file_path in existing_files: filename = os.path.basename(file_path) # 严格匹配 _v 或 'best' match = re.search(r'_(v\d+|best)\.pth$', filename) if match: versions.add(match.group(1)) # 按数字版本降序排序,'best'始终在最前 def sort_key(v): if v == 'best': return -1 # 'best' is always first if v.startswith('v'): return int(v[1:]) return float('inf') # Should not happen sorted_versions = sorted(list(versions), key=sort_key, reverse=True) return sorted_versions def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None): """ 保存模型版本信息到数据库 Args: product_id: 产品ID model_type: 模型类型 version: 版本号 file_path: 模型文件路径 metrics: 模型性能指标 """ import sqlite3 import json from datetime import datetime try: conn = sqlite3.connect('prediction_history.db') cursor = conn.cursor() # 插入模型版本记录 cursor.execute(''' INSERT INTO model_versions ( product_id, model_type, version, file_path, created_at, metrics, is_active ) VALUES (?, ?, ?, ?, ?, ?, ?) ''', ( product_id, model_type, version, file_path, datetime.now().isoformat(), json.dumps(metrics) if metrics else None, 1 # 新模型默认为激活状态 )) conn.commit() conn.close() print(f"已保存模型版本信息: {product_id}_{model_type}_{version}") except Exception as e: print(f"保存模型版本信息失败: {str(e)}") def get_model_version_info(product_id: str, model_type: str, version: str = None): """ 从数据库获取模型版本信息 Args: product_id: 产品ID model_type: 模型类型 version: 版本号,如果为None则获取最新版本 Returns: 模型版本信息字典 """ import sqlite3 import json try: conn = sqlite3.connect('prediction_history.db') conn.row_factory = sqlite3.Row cursor = conn.cursor() if version: cursor.execute(''' SELECT * FROM model_versions WHERE product_id = ? AND model_type = ? AND version = ? ORDER BY created_at DESC LIMIT 1 ''', (product_id, model_type, version)) else: cursor.execute(''' SELECT * FROM model_versions WHERE product_id = ? AND model_type = ? ORDER BY created_at DESC LIMIT 1 ''', (product_id, model_type)) row = cursor.fetchone() conn.close() if row: result = dict(row) if result['metrics']: result['metrics'] = json.loads(result['metrics']) return result return None except Exception as e: print(f"获取模型版本信息失败: {str(e)}") return None