xz2000 b1b697117b **日期**: 2025-07-14
**主题**: UI导航栏重构

### 描述
根据用户请求,对左侧功能导航栏进行了调整。

### 主要改动
1.  **删除“数据管理”**:
    *   从 `UI/src/App.vue` 的导航菜单中移除了“数据管理”项。
    *   从 `UI/src/router/index.js` 中删除了对应的 `/data` 路由。
    *   删除了视图文件 `UI/src/views/DataView.vue`。

2.  **提升“店铺管理”**:
    *   将“店铺管理”菜单项在 `UI/src/App.vue` 中的位置提升,以填补原“数据管理”的位置,使其在导航中更加突出。

### 涉及文件
*   `UI/src/App.vue`
*   `UI/src/router/index.js`
*   `UI/src/views/DataView.vue` (已删除)

**按药品模型预测**
---
**日期**: 2025-07-14
**主题**: 修复导航菜单高亮问题

### 描述
修复了首次进入或刷新页面时,左侧导航菜单项与当前路由不匹配导致不高亮的问题。

### 主要改动
*   **文件**: `UI/src/App.vue`
*   **修改**:
    1.  引入 `useRoute` 和 `computed`。
    2.  创建了一个计算属性 `activeMenu`,其值动态地等于当前路由的路径 (`route.path`)。
    3.  将 `el-menu` 组件的 `:default-active` 属性绑定到 `activeMenu`。

### 结果
确保了导航菜单的高亮状态始终与当前页面的URL保持同步。

---
**日期**: 2025-07-15
**主题**: 修复硬编码文件路径问题,提高项目可移植性

### 问题描述
项目在从一台计算机迁移到另一台时,由于数据文件路径被硬编码在代码中,导致程序无法找到数据文件而运行失败。

### 根本原因
多个Python文件(`predictor.py`, `multi_store_data_utils.py`)中直接写入了相对路径 `'data/timeseries_training_data_sample_10s50p.parquet'` 作为默认值。这种方式在不同运行环境下(如从根目录运行 vs 从子目录运行)会产生路径解析错误。

### 解决方案:集中配置,统一管理
1.  **修改 `server/core/config.py` (核心)**:
    *   动态计算并定义了一个全局变量 `PROJECT_ROOT`,它始终指向项目的根目录。
    *   基于 `PROJECT_ROOT`,使用 `os.path.join` 创建了一个跨平台的、绝对的默认数据路径 `DEFAULT_DATA_PATH` 和模型保存路径 `DEFAULT_MODEL_DIR`。
    *   这确保了无论从哪个位置执行代码,路径总能被正确解析。

2.  **修改 `server/utils/multi_store_data_utils.py`**:
    *   从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    *   将所有数据加载函数的 `file_path` 参数的默认值从硬编码的字符串改为 `None`。
    *   在函数内部,如果 `file_path` 为 `None`,则自动使用导入的 `DEFAULT_DATA_PATH`。
    *   移除了原有的、复杂的、为了猜测正确路径而编写的冗余代码。

3.  **修改 `server/core/predictor.py`**:
    *   同样从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    *   在初始化 `PharmacyPredictor` 时,如果未提供数据路径,则使用导入的 `DEFAULT_DATA_PATH` 作为默认值。

### 最终结果
通过将数据源路径集中到唯一的配置文件中进行管理,彻底解决了因硬编码路径导致的可移植性问题。项目现在可以在任何环境下可靠地运行。

---
### 未来如何修改数据源(例如,连接到服务器数据库)

本次重构为将来更换数据源打下了坚实的基础。操作非常简单:

1.  **定位配置文件**: 打开 `server/core/config.py` 文件。

2.  **修改数据源定义**:
    *   **当前 (文件)**:
        ```python
        DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
        ```
    *   **未来 (数据库示例)**:
        您可以将这行替换为数据库连接字符串,或者添加新的数据库配置变量。例如:
        ```python
        # 注释掉或删除旧的文件路径配置
        # DEFAULT_DATA_PATH = ...

        # 新增数据库连接配置
        DATABASE_URL = "postgresql://user:password@your_server_ip:5432/your_database_name"
        ```

3.  **修改数据加载逻辑**:
    *   **定位数据加载函数**: 打开 `server/utils/multi_store_data_utils.py`。
    *   **修改 `load_multi_store_data` 函数**:
        *   引入数据库连接库(如 `sqlalchemy` 或 `psycopg2`)。
        *   修改函数逻辑,使其使用 `config.py` 中的 `DATABASE_URL` 来连接数据库,并执行SQL查询来获取数据,而不是读取文件。
        *   **示例**:
            ```python
            from sqlalchemy import create_engine
            from core.config import DATABASE_URL # 导入新的数据库配置

            def load_multi_store_data(...):
                # ...
                engine = create_engine(DATABASE_URL)
                query = "SELECT * FROM sales_data" # 根据需要构建查询
                df = pd.read_sql(query, engine)
                # ... 后续处理逻辑保持不变 ...
            ```
2025-07-15 10:37:33 +08:00

286 lines
9.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 全局配置参数
"""
import torch
import matplotlib
matplotlib.use('Agg') # 设置matplotlib后端为Agg适用于无头服务器环境
import matplotlib.pyplot as plt
import os
import re
import glob
# 项目根目录
# __file__ 是当前文件 (config.py) 的路径
# os.path.dirname(__file__) 是 server/core
# os.path.join(..., '..') 是 server
# os.path.join(..., '..', '..') 是项目根目录
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
# 解决画图中文显示问题
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
# 获取设备GPU或CPU
def get_device():
"""获取可用的计算设备GPU或CPU"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
# 全局设备
DEVICE = get_device()
# 数据相关配置
# 使用 os.path.join 构造跨平台的路径
DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
DEFAULT_MODEL_DIR = os.path.join(PROJECT_ROOT, 'saved_models')
DEFAULT_FEATURES = ['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature']
# 时间序列参数
LOOK_BACK = 5 # 使用过去5天数据适应小数据集
FORECAST_HORIZON = 3 # 预测未来3天销量适应小数据集
# 训练参数
DEFAULT_EPOCHS = 50 # 训练轮次
DEFAULT_BATCH_SIZE = 32 # 批大小
DEFAULT_LEARNING_RATE = 0.001 # 学习率
# 模型参数
NUM_FEATURES = 8 # 输入特征数
EMBED_DIM = 32 # 嵌入维度
DENSE_DIM = 32 # 隐藏层神经元数
NUM_HEADS = 4 # 注意力头数
DROPOUT_RATE = 0.1 # 丢弃率
NUM_BLOCKS = 3 # 编码器解码器数
HIDDEN_SIZE = 64 # 隐藏层大小
NUM_LAYERS = 2 # 层数
# 支持的模型类型
SUPPORTED_MODELS = ['mlstm', 'kan', 'transformer', 'tcn', 'optimized_kan']
# 版本管理配置
MODEL_VERSION_PREFIX = 'v' # 版本前缀
DEFAULT_VERSION = 'v1' # 默认版本号
# WebSocket配置
WEBSOCKET_NAMESPACE = '/training' # WebSocket命名空间
TRAINING_UPDATE_INTERVAL = 1 # 训练进度更新间隔(秒)
# 创建模型保存目录
os.makedirs(DEFAULT_MODEL_DIR, exist_ok=True)
def get_next_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的下一个版本号
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
下一个版本号,格式如 'v2', 'v3'
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
# 如果没有任何格式的文件,返回默认版本
if not existing_files_new and not has_old_format:
return DEFAULT_VERSION
# 提取新格式文件的版本号
versions = []
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
versions.append(int(version_match.group(1)))
# 如果存在旧格式文件将其视为v1
if has_old_format:
versions.append(1)
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
if versions:
next_version_num = max(versions) + 1
return f"v{next_version_num}"
else:
return DEFAULT_VERSION
def get_model_file_path(product_id: str, model_type: str, version: str = None) -> str:
"""
生成模型文件路径
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取下一个版本
Returns:
模型文件的完整路径
"""
if version is None:
version = get_next_model_version(product_id, model_type)
# 特殊处理v1版本检查是否存在旧格式文件
if version == "v1":
# 检查旧格式文件是否存在
old_format_filename = f"{model_type}_model_product_{product_id}.pth"
old_format_path = os.path.join(DEFAULT_MODEL_DIR, old_format_filename)
if os.path.exists(old_format_path):
print(f"找到旧格式模型文件: {old_format_path}将其作为v1版本")
return old_format_path
# 使用新格式文件名
filename = f"{model_type}_model_product_{product_id}_{version}.pth"
return os.path.join(DEFAULT_MODEL_DIR, filename)
def get_model_versions(product_id: str, model_type: str) -> list:
"""
获取指定产品和模型类型的所有版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
版本列表,按版本号排序
"""
# 新格式:带版本号的文件
pattern_new = f"{model_type}_model_product_{product_id}_v*.pth"
existing_files_new = glob.glob(os.path.join(DEFAULT_MODEL_DIR, pattern_new))
# 旧格式:不带版本号的文件(兼容性支持)
pattern_old = f"{model_type}_model_product_{product_id}.pth"
old_file_path = os.path.join(DEFAULT_MODEL_DIR, pattern_old)
has_old_format = os.path.exists(old_file_path)
versions = []
# 处理新格式文件
for file_path in existing_files_new:
filename = os.path.basename(file_path)
version_match = re.search(rf"_v(\d+)\.pth$", filename)
if version_match:
version_num = int(version_match.group(1))
versions.append(f"v{version_num}")
# 如果存在旧格式文件将其视为v1
if has_old_format:
if "v1" not in versions: # 避免重复添加
versions.append("v1")
print(f"检测到旧格式模型文件: {old_file_path}将其视为版本v1")
# 按版本号排序
versions.sort(key=lambda v: int(v[1:]))
return versions
def get_latest_model_version(product_id: str, model_type: str) -> str:
"""
获取指定产品和模型类型的最新版本
Args:
product_id: 产品ID
model_type: 模型类型
Returns:
最新版本号如果没有则返回None
"""
versions = get_model_versions(product_id, model_type)
return versions[-1] if versions else None
def save_model_version_info(product_id: str, model_type: str, version: str, file_path: str, metrics: dict = None):
"""
保存模型版本信息到数据库
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号
file_path: 模型文件路径
metrics: 模型性能指标
"""
import sqlite3
import json
from datetime import datetime
try:
conn = sqlite3.connect('prediction_history.db')
cursor = conn.cursor()
# 插入模型版本记录
cursor.execute('''
INSERT INTO model_versions (
product_id, model_type, version, file_path, created_at, metrics, is_active
) VALUES (?, ?, ?, ?, ?, ?, ?)
''', (
product_id,
model_type,
version,
file_path,
datetime.now().isoformat(),
json.dumps(metrics) if metrics else None,
1 # 新模型默认为激活状态
))
conn.commit()
conn.close()
print(f"已保存模型版本信息: {product_id}_{model_type}_{version}")
except Exception as e:
print(f"保存模型版本信息失败: {str(e)}")
def get_model_version_info(product_id: str, model_type: str, version: str = None):
"""
从数据库获取模型版本信息
Args:
product_id: 产品ID
model_type: 模型类型
version: 版本号如果为None则获取最新版本
Returns:
模型版本信息字典
"""
import sqlite3
import json
try:
conn = sqlite3.connect('prediction_history.db')
conn.row_factory = sqlite3.Row
cursor = conn.cursor()
if version:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ? AND version = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type, version))
else:
cursor.execute('''
SELECT * FROM model_versions
WHERE product_id = ? AND model_type = ?
ORDER BY created_at DESC LIMIT 1
''', (product_id, model_type))
row = cursor.fetchone()
conn.close()
if row:
result = dict(row)
if result['metrics']:
result['metrics'] = json.loads(result['metrics'])
return result
return None
except Exception as e:
print(f"获取模型版本信息失败: {str(e)}")
return None