306 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 全局配置参数
"""
import torch
import matplotlib
matplotlib.use('Agg') # 设置matplotlib后端为Agg适用于无头服务器环境
import matplotlib.pyplot as plt
import os
import re
import glob
# 项目根目录
# __file__ 是当前文件 (config.py) 的路径
# os.path.dirname(__file__) 是 server/core
# os.path.join(..., '..') 是 server
# os.path.join(..., '..', '..') 是项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
# 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 获取设备GPU或CPU
def get_device():
"""获取可用的计算设备GPU或CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
# 全局设备
DEVICE = get_device()
# 数据相关配置
# 使用 os.path.join 构造跨平台的路径
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数
LOOK_BACK = 5 # 使用过去5天数据适应小数据集
FORECAST_HORIZON = 3 # 预测未来3天销量适应小数据集
# 训练参数
DEFAULT_EPOCHS = 50 # 训练轮次
DEFAULT_BATCH_SIZE = 32 # 批大小
DEFAULT_LEARNING_RATE = 0.001 # 学习率
# 模型参数
NUM_FEATURES = 8 # 输入特征数
EMBED_DIM = 32 # 嵌入维度
DENSE_DIM = 32 # 隐藏层神经元数
NUM_HEADS = 4 # 注意力头数
DROPOUT_RATE = 0.1 # 丢弃率
NUM_BLOCKS = 3 # 编码器解码器数
HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# 版本管理配置
MODEL_VERSION_PREFIX = 'v' # 版本前缀
DEFAULT_VERSION = 'v1' # 默认版本号
# WebSocket配置
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# 创建模型保存目录
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
def get_next_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的下一个版本号
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
下一个版本号,格式如 'v2', 'v3'
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
# 如果没有任何格式的文件,返回默认版本
if not existing_files_new and not has_old_format:
return DEFAULT_VERSION
# 提取新格式文件的版本号
versions = []
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
versions.append(int(version_match.group(1)))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.append(1)
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
if versions:
next_version_num = max(versions) + 1
return f"v{next_version_num}"
else:
return DEFAULT_VERSION
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
"""
根据产品ID、模型类型和版本号生成模型文件的准确路径。
Args:
product_id: 产品ID (纯数字)
model_type: 模型类型
version: 版本字符串 (例如 'best', 'final_epoch_50', 'v1_legacy')
Returns:
模型文件的完整路径
"""
# 处理历史遗留的 "v1" 格式
if version == "v1_legacy":
filename = f"{model_type}_model_product_{product_id}.pth"
return os.path.join(DEFAULT_MODEL_DIR, filename)
# 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
# 针对 KAN 和 optimized_kan使用 model_manager 的命名约定
if model_type in ['kan', 'optimized_kan']:
# 格式: {model_type}_product_{product_id}_{version}.pth
# 注意KAN trainer 保存时product_id 就是 model_identifier
filename = f"{model_type}_product_{product_id}_{version}.pth"
else:
# 其他模型使用 _epoch_ 约定
filename = f"{model_type}_{product_id}_epoch_{version}.pth"
# 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
def get_model_versions(product_id: str, model_type: str) -> list:
"""
获取指定产品和模型类型的所有版本
Args:
product_id: 产品ID (现在应该是纯数字ID)
model_type: 模型类型
Returns:
版本列表,按版本号排序
"""
# 直接使用传入的product_id构建搜索模式
# 搜索模式,匹配 "transformer_product_17002608_epoch_50.pth" 或 "transformer_product_17002608_epoch_best.pth"
# 修正直接使用唯一的product_id它可能包含store_前缀来构建搜索模式
# 扩展搜索模式以兼容多种命名约定
patterns = [
f"{model_type}_{product_id}_epoch_*.pth", # 原始格式 (e.g., transformer_123_epoch_best.pth)
f"{model_type}_product_{product_id}_*.pth" # KAN/ModelManager格式 (e.g., kan_product_123_v1.pth)
]
existing_files = []
for pattern in patterns:
search_path = os.path.join(DEFAULT_MODEL_DIR, pattern)
existing_files.extend(glob.glob(search_path))
# 旧格式(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
if os.path.exists(old_file_path):
existing_files.append(old_file_path)
versions = set() # 使用集合避免重复
# 从找到的文件中提取版本信息
for file_path in existing_files:
filename = os.path.basename(file_path)
# 尝试匹配 _epoch_ 格式
version_match_epoch = re.search(r"_epoch_(.+)\.pth$", filename)
if version_match_epoch:
versions.add(version_match_epoch.group(1))
continue
# 尝试匹配 _product_..._v 格式 (KAN)
version_match_kan = re.search(r"_product_.+_v(\d+)\.pth$", filename)
if version_match_kan:
versions.add(f"v{version_match_kan.group(1)}")
continue
# 尝试匹配旧的 _model_product_ 格式
if pattern_old in filename:
versions.add("v1_legacy") # 添加一个特殊标识
print(f"检测到旧格式模型文件: {old_file_path},将其视为版本 v1_legacy")
continue
# 转换为列表并排序
sorted_versions = sorted(list(versions))
return sorted_versions
def get_latest_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的最新版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
最新版本号如果没有则返回None
"""
versions = get_model_versions(product_id, model_type)
return versions[-1] if versions else None
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
"""
保存模型版本信息到数据库
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号
file_path: 模型文件路径
metrics: 模型性能指标
"""
import sqlite3
import json
from datetime import datetime
try:
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 插入模型版本记录
cursor.execute('''
INSERT INTO model_versions (
product_id, model_type, version, file_path, created_at, metrics, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
product_id,
model_type,
version,
file_path,
datetime.now().isoformat(),
json.dumps(metrics) if metrics else None,
1 # 新模型默认为激活状态
))
conn.commit()
conn.close()
print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")
except Exception as e:
print(f"保存模型版本信息失败: {str(e)}")
def get_model_version_info(product_id: str, model_type: str, version: str = None):
"""
从数据库获取模型版本信息
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取最新版本
Returns:
模型版本信息字典
"""
import sqlite3
import json
try:
conn = sqlite3.connect('prediction_history.db')
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
if version:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ? AND version = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type, version))
else:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type))
row = cursor.fetchone()
conn.close()
if row:
result = dict(row)
if result['metrics']:
result['metrics'] = json.loads(result['metrics'])
return result
return None
except Exception as e:
print(f"获取模型版本信息失败: {str(e)}")
return None