2025-06-18 06:39:41 +08:00
|
|
|
|
"""
|
|
|
|
|
药店销售预测系统 - 全局配置参数
|
|
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
import torch
|
|
|
|
|
import matplotlib
|
|
|
|
|
matplotlib.use('Agg') # 设置matplotlib后端为Agg,适用于无头服务器环境
|
|
|
|
|
import matplotlib.pyplot as plt
|
|
|
|
|
import os
|
2025-07-02 11:05:23 +08:00
|
|
|
|
import re
|
|
|
|
|
import glob
|
2025-06-18 06:39:41 +08:00
|
|
|
|
|
2025-07-15 10:37:25 +08:00
|
|
|
|
# Project root directory.
# Per the original notes this file (config.py) sits in server/core, so the
# directory containing it is server/core, one level up is server, and two
# levels up is the project root.
PROJECT_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
|
|
|
|
|
|
2025-06-18 06:39:41 +08:00
|
|
|
|
# Work around Chinese text not rendering in plots: select the SimHei font for
# sans-serif text and turn off the Unicode minus sign for axis labels.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})
|
|
|
|
|
|
|
|
|
|
# 获取设备(GPU或CPU)
|
|
|
|
|
def get_device():
    """Return the compute device: CUDA when it is available, otherwise CPU."""
    backend = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(backend)


# Global device, resolved once when this module is imported.
DEVICE = get_device()
|
|
|
|
|
|
|
|
|
|
# Data-related configuration.
# Paths are built with os.path.join so they work on every platform.
DEFAULT_DATA_PATH = os.path.join(
    PROJECT_ROOT,
    'data',
    'timeseries_training_data_sample_10s50p.parquet',
)
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
|
2025-06-18 06:39:41 +08:00
|
|
|
|
# Default feature columns.
DEFAULT_FEATURES = [
    'sales',
    'price',
    'weekday',
    'month',
    'is_holiday',
    'is_weekend',
    'is_promotion',
    'temperature',
]

# Time-series parameters.
# Use the past 5 days of data (sized for small datasets).
LOOK_BACK = 5
# Forecast sales for the next 3 days (sized for small datasets).
FORECAST_HORIZON = 3

# Training parameters.
# Number of training epochs.
DEFAULT_EPOCHS = 50
# Batch size.
DEFAULT_BATCH_SIZE = 32
# Learning rate.
DEFAULT_LEARNING_RATE = 0.001

# Model parameters.
# Number of input features (DEFAULT_FEATURES currently lists 8 columns).
NUM_FEATURES = 8
# Embedding dimension.
EMBED_DIM = 32
# Number of neurons in the hidden (dense) layer.
DENSE_DIM = 32
# Number of attention heads.
NUM_HEADS = 4
# Dropout rate.
DROPOUT_RATE = 0.1
# Number of encoder/decoder blocks.
NUM_BLOCKS = 3
# Hidden layer size.
HIDDEN_SIZE = 64
# Number of layers.
NUM_LAYERS = 2

# Supported model types.
SUPPORTED_MODELS = [
    'mlstm',
    'kan',
    'transformer',
    'tcn',
    'optimized_kan',
]

# Version management configuration.
# Prefix used for version identifiers.
MODEL_VERSION_PREFIX = 'v'
# Default version identifier.
DEFAULT_VERSION = 'v1'

# WebSocket configuration.
# WebSocket namespace.
WEBSOCKET_NAMESPACE = '/training'
# Interval between training progress updates, in seconds.
TRAINING_UPDATE_INTERVAL = 1
|
|
|
|
|
|
2025-06-18 06:39:41 +08:00
|
|
|
|
# Create the model output directory when this module is imported; a directory
# that already exists is accepted silently.
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
2025-07-15 20:06:17 +08:00
|
|
|
|
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
    """
    Build the path of a model file from product ID, model type and version.

    The file is always placed directly inside DEFAULT_MODEL_DIR (no
    ``checkpoints`` sub-directory). Two file-name schemes are produced:

    * ``version == "v1_legacy"``: ``{model_type}_model_product_{product_id}.pth``
    * any other version: ``{model_type}_product_{product_id}_{version}.pth``

    The function only formats the path; it does not check that the file exists.

    Args:
        product_id: Product identifier, inserted into the file name verbatim
            (the original notes say it may carry a ``store_`` prefix).
        model_type: Model type.
        version: Version string (for example 'best', 'final_epoch_50',
            'v1_legacy').

    Returns:
        Full path of the model file.
    """
    if version == "v1_legacy":
        # Historical "v1" files carry no version suffix in their name.
        basename = f"{model_type}_model_product_{product_id}.pth"
    else:
        # Unified naming scheme shared by all model types.
        basename = f"{model_type}_product_{product_id}_{version}.pth"
    return os.path.join(DEFAULT_MODEL_DIR, basename)
|
2025-07-02 11:05:23 +08:00
|
|
|
|
|
|
|
|
|
def get_model_versions(product_id: str, model_type: str) -> list:
    """
    List the saved versions of a model for one product.

    Files named ``{model_type}_product_{product_id}_*.pth`` are searched
    directly inside DEFAULT_MODEL_DIR. Only files whose name ends in
    ``_v<number>.pth`` or ``_best.pth`` contribute a version; anything else
    (for example ``..._final_epoch_50.pth``) is ignored.

    Args:
        product_id: Product identifier (expected to be the plain numeric ID).
        model_type: Model type.

    Returns:
        Version strings without duplicates: 'best' first when present,
        followed by the numbered versions from highest to lowest,
        e.g. ``['best', 'v10', 'v2', 'v1']``. Empty list when nothing matches.
    """
    # Escape every interpolated part so glob metacharacters ('*', '?', '[')
    # in the directory name or in the identifiers are matched literally.
    pattern = os.path.join(
        glob.escape(DEFAULT_MODEL_DIR),
        f"{glob.escape(str(model_type))}_product_{glob.escape(str(product_id))}_*.pth",
    )

    # Strictly accept a trailing _v<number> or _best.
    version_suffix = re.compile(r'_(v\d+|best)\.pth$')

    versions = set()
    for file_path in glob.glob(pattern):
        match = version_suffix.search(os.path.basename(file_path))
        if match:
            versions.add(match.group(1))

    def sort_key(v):
        # Ascending sort on this key: group 0 ('best') precedes group 1
        # (numbered versions), and negating the number puts the highest
        # version first. The string itself breaks ties such as 'v1' / 'v01'
        # so the result does not depend on set iteration order.
        if v == 'best':
            return (0, 0, v)
        return (1, -int(v[1:]), v)

    return sorted(versions, key=sort_key)
|
2025-07-02 11:05:23 +08:00
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
    """
    Save a model version record to the database.

    Inserts one row into the ``model_versions`` table of the SQLite file
    'prediction_history.db'. The path is relative, so it is resolved against
    the current working directory. The row is stamped with the current local
    time in ISO format and marked active (``is_active = 1``).

    This is a best-effort operation: any failure (for example a missing
    table) is printed and swallowed, never raised to the caller.

    Args:
        product_id: Product ID.
        model_type: Model type.
        version: Version identifier.
        file_path: Path of the model file.
        metrics: Model performance metrics; stored as a JSON string. ``None``
            or an empty dict is stored as NULL.

    Returns:
        None.
    """
    import sqlite3
    import json
    from datetime import datetime

    try:
        conn = sqlite3.connect('prediction_history.db')
        try:
            cursor = conn.cursor()

            # Insert the model version record.
            cursor.execute('''
                INSERT INTO model_versions (
                    product_id, model_type, version, file_path, created_at, metrics, is_active
                ) VALUES (?, ?, ?, ?, ?, ?, ?)
            ''', (
                product_id,
                model_type,
                version,
                file_path,
                datetime.now().isoformat(),
                json.dumps(metrics) if metrics else None,
                1,  # a new model is active by default
            ))

            conn.commit()
        finally:
            # Always release the connection, including when the insert or
            # the commit fails.
            conn.close()

        print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")

    except Exception as e:
        print(f"保存模型版本信息失败: {str(e)}")
|
|
|
|
|
|
|
|
|
|
def get_model_version_info(product_id: str, model_type: str, version: str = None):
    """
    Fetch a model version record from the database.

    Reads the ``model_versions`` table of the SQLite file
    'prediction_history.db'. The path is relative, so it is resolved against
    the current working directory. When several rows match, the one with the
    greatest ``created_at`` value is returned.

    Args:
        product_id: Product ID.
        model_type: Model type.
        version: Version identifier. When ``None`` (or empty) the most
            recently created record for the product/model pair is returned
            regardless of its version.

    Returns:
        A dict holding the columns of the matching row, with ``metrics``
        decoded from JSON when it is set; ``None`` when no row matches or
        when any error occurs (the error is printed, not raised).
    """
    import sqlite3
    import json

    try:
        # One parameterized statement; the version filter is added on demand.
        query = 'SELECT * FROM model_versions WHERE product_id = ? AND model_type = ?'
        params = [product_id, model_type]
        if version:
            query += ' AND version = ?'
            params.append(version)
        query += ' ORDER BY created_at DESC LIMIT 1'

        conn = sqlite3.connect('prediction_history.db')
        try:
            # sqlite3.Row lets the row be converted to a dict keyed by column name.
            conn.row_factory = sqlite3.Row
            cursor = conn.cursor()
            cursor.execute(query, params)
            row = cursor.fetchone()
        finally:
            # Always release the connection, including when the query fails.
            conn.close()

        if row is None:
            return None

        result = dict(row)
        if result['metrics']:
            result['metrics'] = json.loads(result['metrics'])
        return result

    except Exception as e:
        print(f"获取模型版本信息失败: {str(e)}")
        return None
|