228 lines
7.2 KiB
Python
Raw Permalink Normal View History

"""
药店销售预测系统 - 全局配置参数
"""
import torch
import matplotlib
matplotlib.use('Agg') # 设置matplotlib后端为Agg适用于无头服务器环境
import matplotlib.pyplot as plt
import os
2025-07-02 11:05:23 +08:00
import re
import glob
**日期**: 2025-07-14 **主题**: UI导航栏重构 ### 描述 根据用户请求,对左侧功能导航栏进行了调整。 ### 主要改动 1. **删除“数据管理”**: * 从 `UI/src/App.vue` 的导航菜单中移除了“数据管理”项。 * 从 `UI/src/router/index.js` 中删除了对应的 `/data` 路由。 * 删除了视图文件 `UI/src/views/DataView.vue`。 2. **提升“店铺管理”**: * 将“店铺管理”菜单项在 `UI/src/App.vue` 中的位置提升,以填补原“数据管理”的位置,使其在导航中更加突出。 ### 涉及文件 * `UI/src/App.vue` * `UI/src/router/index.js` * `UI/src/views/DataView.vue` (已删除) **按药品模型预测** --- **日期**: 2025-07-14 **主题**: 修复导航菜单高亮问题 ### 描述 修复了首次进入或刷新页面时,左侧导航菜单项与当前路由不匹配导致不高亮的问题。 ### 主要改动 * **文件**: `UI/src/App.vue` * **修改**: 1. 引入 `useRoute` 和 `computed`。 2. 创建了一个计算属性 `activeMenu`,其值动态地等于当前路由的路径 (`route.path`)。 3. 将 `el-menu` 组件的 `:default-active` 属性绑定到 `activeMenu`。 ### 结果 确保了导航菜单的高亮状态始终与当前页面的URL保持同步。 --- **日期**: 2025-07-15 **主题**: 修复硬编码文件路径问题,提高项目可移植性 ### 问题描述 项目在从一台计算机迁移到另一台时,由于数据文件路径被硬编码在代码中,导致程序无法找到数据文件而运行失败。 ### 根本原因 多个Python文件(`predictor.py`, `multi_store_data_utils.py`)中直接写入了相对路径 `'data/timeseries_training_data_sample_10s50p.parquet'` 作为默认值。这种方式在不同运行环境下(如从根目录运行 vs 从子目录运行)会产生路径解析错误。 ### 解决方案:集中配置,统一管理 1. **修改 `server/core/config.py` (核心)**: * 动态计算并定义了一个全局变量 `PROJECT_ROOT`,它始终指向项目的根目录。 * 基于 `PROJECT_ROOT`,使用 `os.path.join` 创建了一个跨平台的、绝对的默认数据路径 `DEFAULT_DATA_PATH` 和模型保存路径 `DEFAULT_MODEL_DIR`。 * 这确保了无论从哪个位置执行代码,路径总能被正确解析。 2. **修改 `server/utils/multi_store_data_utils.py`**: * 从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。 * 将所有数据加载函数的 `file_path` 参数的默认值从硬编码的字符串改为 `None`。 * 在函数内部,如果 `file_path` 为 `None`,则自动使用导入的 `DEFAULT_DATA_PATH`。 * 移除了原有的、复杂的、为了猜测正确路径而编写的冗余代码。 3. **修改 `server/core/predictor.py`**: * 同样从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。 * 在初始化 `PharmacyPredictor` 时,如果未提供数据路径,则使用导入的 `DEFAULT_DATA_PATH` 作为默认值。 ### 最终结果 通过将数据源路径集中到唯一的配置文件中进行管理,彻底解决了因硬编码路径导致的可移植性问题。项目现在可以在任何环境下可靠地运行。 --- ### 未来如何修改数据源(例如,连接到服务器数据库) 本次重构为将来更换数据源打下了坚实的基础。操作非常简单: 1. **定位配置文件**: 打开 `server/core/config.py` 文件。 2. **修改数据源定义**: * **当前 (文件)**: ```python DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet') ``` * **未来 (数据库示例)**: 您可以将这行替换为数据库连接字符串,或者添加新的数据库配置变量。例如: ```python # 注释掉或删除旧的文件路径配置 # DEFAULT_DATA_PATH = ... # 新增数据库连接配置 DATABASE_URL = "postgresql://user:password@your_server_ip:5432/your_database_name" ``` 3. **修改数据加载逻辑**: * **定位数据加载函数**: 打开 `server/utils/multi_store_data_utils.py`。 * **修改 `load_multi_store_data` 函数**: * 引入数据库连接库(如 `sqlalchemy` 或 `psycopg2`)。 * 修改函数逻辑,使其使用 `config.py` 中的 `DATABASE_URL` 来连接数据库,并执行SQL查询来获取数据,而不是读取文件。 * **示例**: ```python from sqlalchemy import create_engine from core.config import DATABASE_URL # 导入新的数据库配置 def load_multi_store_data(...): # ... engine = create_engine(DATABASE_URL) query = "SELECT * FROM sales_data" # 根据需要构建查询 df = pd.read_sql(query, engine) # ... 后续处理逻辑保持不变 ... ```
2025-07-15 10:37:25 +08:00
# 项目根目录
# __file__ 是当前文件 (config.py) 的路径
# os.path.dirname(__file__) 是 server/core
# os.path.join(..., '..') 是 server
# os.path.join(..., '..', '..') 是项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
# 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 获取设备GPU或CPU
def get_device():
"""获取可用的计算设备GPU或CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
# 全局设备
DEVICE = get_device()
# 数据相关配置
**日期**: 2025-07-14 **主题**: UI导航栏重构 ### 描述 根据用户请求,对左侧功能导航栏进行了调整。 ### 主要改动 1. **删除“数据管理”**: * 从 `UI/src/App.vue` 的导航菜单中移除了“数据管理”项。 * 从 `UI/src/router/index.js` 中删除了对应的 `/data` 路由。 * 删除了视图文件 `UI/src/views/DataView.vue`。 2. **提升“店铺管理”**: * 将“店铺管理”菜单项在 `UI/src/App.vue` 中的位置提升,以填补原“数据管理”的位置,使其在导航中更加突出。 ### 涉及文件 * `UI/src/App.vue` * `UI/src/router/index.js` * `UI/src/views/DataView.vue` (已删除) **按药品模型预测** --- **日期**: 2025-07-14 **主题**: 修复导航菜单高亮问题 ### 描述 修复了首次进入或刷新页面时,左侧导航菜单项与当前路由不匹配导致不高亮的问题。 ### 主要改动 * **文件**: `UI/src/App.vue` * **修改**: 1. 引入 `useRoute` 和 `computed`。 2. 创建了一个计算属性 `activeMenu`,其值动态地等于当前路由的路径 (`route.path`)。 3. 将 `el-menu` 组件的 `:default-active` 属性绑定到 `activeMenu`。 ### 结果 确保了导航菜单的高亮状态始终与当前页面的URL保持同步。 --- **日期**: 2025-07-15 **主题**: 修复硬编码文件路径问题,提高项目可移植性 ### 问题描述 项目在从一台计算机迁移到另一台时,由于数据文件路径被硬编码在代码中,导致程序无法找到数据文件而运行失败。 ### 根本原因 多个Python文件(`predictor.py`, `multi_store_data_utils.py`)中直接写入了相对路径 `'data/timeseries_training_data_sample_10s50p.parquet'` 作为默认值。这种方式在不同运行环境下(如从根目录运行 vs 从子目录运行)会产生路径解析错误。 ### 解决方案:集中配置,统一管理 1. **修改 `server/core/config.py` (核心)**: * 动态计算并定义了一个全局变量 `PROJECT_ROOT`,它始终指向项目的根目录。 * 基于 `PROJECT_ROOT`,使用 `os.path.join` 创建了一个跨平台的、绝对的默认数据路径 `DEFAULT_DATA_PATH` 和模型保存路径 `DEFAULT_MODEL_DIR`。 * 这确保了无论从哪个位置执行代码,路径总能被正确解析。 2. **修改 `server/utils/multi_store_data_utils.py`**: * 从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。 * 将所有数据加载函数的 `file_path` 参数的默认值从硬编码的字符串改为 `None`。 * 在函数内部,如果 `file_path` 为 `None`,则自动使用导入的 `DEFAULT_DATA_PATH`。 * 移除了原有的、复杂的、为了猜测正确路径而编写的冗余代码。 3. **修改 `server/core/predictor.py`**: * 同样从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。 * 在初始化 `PharmacyPredictor` 时,如果未提供数据路径,则使用导入的 `DEFAULT_DATA_PATH` 作为默认值。 ### 最终结果 通过将数据源路径集中到唯一的配置文件中进行管理,彻底解决了因硬编码路径导致的可移植性问题。项目现在可以在任何环境下可靠地运行。 --- ### 未来如何修改数据源(例如,连接到服务器数据库) 本次重构为将来更换数据源打下了坚实的基础。操作非常简单: 1. **定位配置文件**: 打开 `server/core/config.py` 文件。 2. **修改数据源定义**: * **当前 (文件)**: ```python DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet') ``` * **未来 (数据库示例)**: 您可以将这行替换为数据库连接字符串,或者添加新的数据库配置变量。例如: ```python # 注释掉或删除旧的文件路径配置 # DEFAULT_DATA_PATH = ... # 新增数据库连接配置 DATABASE_URL = "postgresql://user:password@your_server_ip:5432/your_database_name" ``` 3. **修改数据加载逻辑**: * **定位数据加载函数**: 打开 `server/utils/multi_store_data_utils.py`。 * **修改 `load_multi_store_data` 函数**: * 引入数据库连接库(如 `sqlalchemy` 或 `psycopg2`)。 * 修改函数逻辑,使其使用 `config.py` 中的 `DATABASE_URL` 来连接数据库,并执行SQL查询来获取数据,而不是读取文件。 * **示例**: ```python from sqlalchemy import create_engine from core.config import DATABASE_URL # 导入新的数据库配置 def load_multi_store_data(...): # ... engine = create_engine(DATABASE_URL) query = "SELECT * FROM sales_data" # 根据需要构建查询 df = pd.read_sql(query, engine) # ... 后续处理逻辑保持不变 ... ```
2025-07-15 10:37:25 +08:00
# 使用 os.path.join 构造跨平台的路径
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
DEFAULT_PREDICTIONS_DIR = os.path.join(PROJECT_ROOT, 'saved_predictions')
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数
2025-07-02 11:05:23 +08:00
LOOK_BACK = 5 # 使用过去5天数据适应小数据集
FORECAST_HORIZON = 3 # 预测未来3天销量适应小数据集
# 训练参数
DEFAULT_EPOCHS = 50 # 训练轮次
DEFAULT_BATCH_SIZE = 32 # 批大小
DEFAULT_LEARNING_RATE = 0.001 # 学习率
# 模型参数
NUM_FEATURES = 8 # 输入特征数
EMBED_DIM = 32 # 嵌入维度
DENSE_DIM = 32 # 隐藏层神经元数
NUM_HEADS = 4 # 注意力头数
DROPOUT_RATE = 0.1 # 丢弃率
NUM_BLOCKS = 3 # 编码器解码器数
HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
2025-07-22 15:40:37 +08:00
# 支持的模型类型 (v2 - 动态加载)
from models.model_registry import TRAINER_REGISTRY
SUPPORTED_MODELS = list(TRAINER_REGISTRY.keys())
2025-07-02 11:05:23 +08:00
# 版本管理配置
MODEL_VERSION_PREFIX = 'v' # 版本前缀
DEFAULT_VERSION = 'v1' # 默认版本号
# WebSocket配置
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# 创建模型和预测结果保存目录
2025-07-02 11:05:23 +08:00
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
os.makedirs(DEFAULT_PREDICTIONS_DIR, exist_ok=True)
2025-07-02 11:05:23 +08:00
2025-07-15 20:06:17 +08:00
def get_model_file_path(product_id: str, model_type: str, version: str) -> str:
2025-07-02 11:05:23 +08:00
"""
2025-07-15 20:06:17 +08:00
根据产品ID模型类型和版本号生成模型文件的准确路径
2025-07-02 11:05:23 +08:00
Args:
2025-07-15 20:06:17 +08:00
product_id: 产品ID (纯数字)
2025-07-02 11:05:23 +08:00
model_type: 模型类型
2025-07-15 20:06:17 +08:00
version: 版本字符串 (例如 'best', 'final_epoch_50', 'v1_legacy')
2025-07-02 11:05:23 +08:00
Returns:
模型文件的完整路径
"""
2025-07-15 20:06:17 +08:00
# 处理历史遗留的 "v1" 格式
if version == "v1_legacy":
filename = f"{model_type}_model_product_{product_id}.pth"
return os.path.join(DEFAULT_MODEL_DIR, filename)
2025-07-16 16:24:08 +08:00
# 修正直接使用唯一的product_id它可能包含store_前缀来构建文件名
# 文件名示例: transformer_17002608_epoch_best.pth 或 transformer_store_01010023_epoch_best.pth
2025-07-17 17:54:53 +08:00
# 针对 KAN 和 optimized_kan使用 model_manager 的命名约定
# 统一所有模型的命名格式
filename = f"{model_type}_product_{product_id}_{version}.pth"
2025-07-16 16:24:08 +08:00
# 修正直接在根模型目录查找不再使用checkpoints子目录
return os.path.join(DEFAULT_MODEL_DIR, filename)
2025-07-02 11:05:23 +08:00
def get_model_versions(product_id: str, model_type: str) -> list:
"""
获取指定产品和模型类型的所有版本
Args:
2025-07-15 20:06:17 +08:00
product_id: 产品ID (现在应该是纯数字ID)
2025-07-02 11:05:23 +08:00
model_type: 模型类型
Returns:
版本列表按版本号排序
"""
# 统一使用新的命名约定进行搜索
pattern = os.path.join(DEFAULT_MODEL_DIR, f"{model_type}_product_{product_id}_*.pth")
existing_files = glob.glob(pattern)
2025-07-15 20:06:17 +08:00
versions = set()
2025-07-02 11:05:23 +08:00
2025-07-15 20:06:17 +08:00
for file_path in existing_files:
2025-07-02 11:05:23 +08:00
filename = os.path.basename(file_path)
2025-07-17 17:54:53 +08:00
# 严格匹配 _v<number> 或 'best'
match = re.search(r'_(v\d+|best)\.pth$', filename)
if match:
versions.add(match.group(1))
2025-07-17 17:54:53 +08:00
# 按数字版本降序排序,'best'始终在最前
def sort_key(v):
if v == 'best':
return -1 # 'best' is always first
if v.startswith('v'):
return int(v[1:])
return float('inf') # Should not happen
sorted_versions = sorted(list(versions), key=sort_key, reverse=True)
2025-07-15 20:06:17 +08:00
return sorted_versions
2025-07-02 11:05:23 +08:00
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
"""
保存模型版本信息到数据库
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号
file_path: 模型文件路径
metrics: 模型性能指标
"""
import sqlite3
import json
from datetime import datetime
try:
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 插入模型版本记录
cursor.execute('''
INSERT INTO model_versions (
product_id, model_type, version, file_path, created_at, metrics, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
product_id,
model_type,
version,
file_path,
datetime.now().isoformat(),
json.dumps(metrics) if metrics else None,
1 # 新模型默认为激活状态
))
conn.commit()
conn.close()
print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")
except Exception as e:
print(f"保存模型版本信息失败: {str(e)}")
def get_model_version_info(product_id: str, model_type: str, version: str = None):
"""
从数据库获取模型版本信息
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取最新版本
Returns:
模型版本信息字典
"""
import sqlite3
import json
try:
conn = sqlite3.connect('prediction_history.db')
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
if version:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ? AND version = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type, version))
else:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type))
row = cursor.fetchone()
conn.close()
if row:
result = dict(row)
if result['metrics']:
result['metrics'] = json.loads(result['metrics'])
return result
return None
except Exception as e:
print(f"获取模型版本信息失败: {str(e)}")
return None