ShopTRAINING/server/core/predictor.py
xz2000 b1b697117b **日期**: 2025-07-14
**主题**: UI导航栏重构

### 描述
根据用户请求,对左侧功能导航栏进行了调整。

### 主要改动
1.  **删除“数据管理”**:
    *   从 `UI/src/App.vue` 的导航菜单中移除了“数据管理”项。
    *   从 `UI/src/router/index.js` 中删除了对应的 `/data` 路由。
    *   删除了视图文件 `UI/src/views/DataView.vue`。

2.  **提升“店铺管理”**:
    *   将“店铺管理”菜单项在 `UI/src/App.vue` 中的位置提升,以填补原“数据管理”的位置,使其在导航中更加突出。

### 涉及文件
*   `UI/src/App.vue`
*   `UI/src/router/index.js`
*   `UI/src/views/DataView.vue` (已删除)

**按药品模型预测**
---
**日期**: 2025-07-14
**主题**: 修复导航菜单高亮问题

### 描述
修复了首次进入或刷新页面时,左侧导航菜单项与当前路由不匹配导致不高亮的问题。

### 主要改动
*   **文件**: `UI/src/App.vue`
*   **修改**:
    1.  引入 `useRoute` 和 `computed`。
    2.  创建了一个计算属性 `activeMenu`,其值动态地等于当前路由的路径 (`route.path`)。
    3.  将 `el-menu` 组件的 `:default-active` 属性绑定到 `activeMenu`。

### 结果
确保了导航菜单的高亮状态始终与当前页面的URL保持同步。

---
**日期**: 2025-07-15
**主题**: 修复硬编码文件路径问题,提高项目可移植性

### 问题描述
项目在从一台计算机迁移到另一台时,由于数据文件路径被硬编码在代码中,导致程序无法找到数据文件而运行失败。

### 根本原因
多个Python文件(`predictor.py`, `multi_store_data_utils.py`)中直接写入了相对路径 `'data/timeseries_training_data_sample_10s50p.parquet'` 作为默认值。这种方式在不同运行环境下(如从根目录运行 vs 从子目录运行)会产生路径解析错误。

### 解决方案:集中配置,统一管理
1.  **修改 `server/core/config.py` (核心)**:
    *   动态计算并定义了一个全局变量 `PROJECT_ROOT`,它始终指向项目的根目录。
    *   基于 `PROJECT_ROOT`,使用 `os.path.join` 创建了一个跨平台的、绝对的默认数据路径 `DEFAULT_DATA_PATH` 和模型保存路径 `DEFAULT_MODEL_DIR`。
    *   这确保了无论从哪个位置执行代码,路径总能被正确解析。

2.  **修改 `server/utils/multi_store_data_utils.py`**:
    *   从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    *   将所有数据加载函数的 `file_path` 参数的默认值从硬编码的字符串改为 `None`。
    *   在函数内部,如果 `file_path` 为 `None`,则自动使用导入的 `DEFAULT_DATA_PATH`。
    *   移除了原有的、复杂的、为了猜测正确路径而编写的冗余代码。

3.  **修改 `server/core/predictor.py`**:
    *   同样从 `server/core/config` 导入 `DEFAULT_DATA_PATH`。
    *   在初始化 `PharmacyPredictor` 时,如果未提供数据路径,则使用导入的 `DEFAULT_DATA_PATH` 作为默认值。

### 最终结果
通过将数据源路径集中到唯一的配置文件中进行管理,彻底解决了因硬编码路径导致的可移植性问题。项目现在可以在任何环境下可靠地运行。

---
### 未来如何修改数据源(例如,连接到服务器数据库)

本次重构为将来更换数据源打下了坚实的基础。操作非常简单:

1.  **定位配置文件**: 打开 `server/core/config.py` 文件。

2.  **修改数据源定义**:
    *   **当前 (文件)**:
        ```python
        DEFAULT_DATA_PATH = os.path.join(PROJECT_ROOT, 'data', 'timeseries_training_data_sample_10s50p.parquet')
        ```
    *   **未来 (数据库示例)**:
        您可以将这行替换为数据库连接字符串,或者添加新的数据库配置变量。例如:
        ```python
        # 注释掉或删除旧的文件路径配置
        # DEFAULT_DATA_PATH = ...

        # 新增数据库连接配置
        DATABASE_URL = "postgresql://user:password@your_server_ip:5432/your_database_name"
        ```

3.  **修改数据加载逻辑**:
    *   **定位数据加载函数**: 打开 `server/utils/multi_store_data_utils.py`。
    *   **修改 `load_multi_store_data` 函数**:
        *   引入数据库连接库(如 `sqlalchemy` 或 `psycopg2`)。
        *   修改函数逻辑,使其使用 `config.py` 中的 `DATABASE_URL` 来连接数据库,并执行SQL查询来获取数据,而不是读取文件。
        *   **示例**:
            ```python
            from sqlalchemy import create_engine
            from core.config import DATABASE_URL # 导入新的数据库配置

            def load_multi_store_data(...):
                # ...
                engine = create_engine(DATABASE_URL)
                query = "SELECT * FROM sales_data" # 根据需要构建查询
                df = pd.read_sql(query, engine)
                # ... 后续处理逻辑保持不变 ...
            ```
2025-07-15 10:37:33 +08:00

548 lines
22 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类
支持多店铺销售预测功能
"""
import os
import pandas as pd
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
load_multi_store_data,
get_store_product_sales_data,
aggregate_multi_store_data
)
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
class PharmacyPredictor:
"""
药店销售预测系统核心类,用于训练模型和进行预测
"""
def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
初始化预测器
参数:
data_path: 数据文件路径默认使用多店铺CSV文件
model_dir: 模型保存目录
"""
# 设置默认数据路径为多店铺CSV文件
if data_path is None:
data_path = DEFAULT_DATA_PATH
self.data_path = data_path
self.model_dir = model_dir
self.device = DEVICE
if not os.path.exists(model_dir):
os.makedirs(model_dir)
print(f"使用设备: {self.device}")
# 尝试加载多店铺数据
try:
self.data = load_multi_store_data(data_path)
print(f"已加载多店铺数据,来源: {data_path}")
except Exception as e:
print(f"加载数据失败: {e}")
self.data = None
def train_model(self, product_id, model_type='transformer', epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
store_id=None, training_mode='product', aggregation_method='sum',
socketio=None, task_id=None, version=None, continue_training=False,
progress_callback=None):
"""
训练预测模型 - 支持多店铺训练
参数:
product_id: 产品ID
model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
epochs: 训练轮次
batch_size: 批次大小
learning_rate: 学习率
sequence_length: 输入序列长度
forecast_horizon: 预测天数
hidden_size: 隐藏层大小
num_layers: 层数
dropout: Dropout比例
use_optimized: 是否使用优化版KAN仅当model_type为'kan'时有效)
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局训练
返回:
metrics: 模型评估指标
"""
# 创建统一的输出函数
def log_message(message, log_type='info'):
"""统一的日志输出函数"""
print(message, flush=True) # 始终输出到控制台
# 如果有进度回调,也发送到回调
if progress_callback:
try:
progress_callback({
'log_type': log_type,
'message': message
})
except Exception as e:
print(f"进度回调失败: {e}", flush=True)
if self.data is None:
log_message("没有可用的数据,请先加载或生成数据", 'error')
return None
# 根据训练模式准备数据
if training_mode == 'product':
# 按产品训练:使用所有店铺的该产品数据
product_data = self.data[self.data['product_id'] == product_id].copy()
if product_data.empty:
log_message(f"找不到产品 {product_id} 的数据", 'error')
return None
log_message(f"按产品训练模式: 产品 {product_id}, 数据量: {len(product_data)}")
elif training_mode == 'store':
# 按店铺训练
if not store_id:
log_message("店铺训练模式需要指定 store_id", 'error')
return None
# 如果product_id是'unknown',则表示为店铺所有商品训练一个聚合模型
if product_id == 'unknown':
try:
# 使用新的聚合函数,按店铺聚合
product_data = aggregate_multi_store_data(
store_id=store_id,
aggregation_method=aggregation_method,
file_path=self.data_path
)
log_message(f"按店铺聚合训练: 店铺 {store_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
# 将product_id设置为店铺ID以便模型保存时使用有意义的标识
product_id = store_id
except Exception as e:
log_message(f"聚合店铺 {store_id} 数据失败: {e}", 'error')
return None
else:
# 为店铺的单个特定产品训练
try:
product_data = get_store_product_sales_data(
store_id=store_id,
product_id=product_id,
file_path=self.data_path
)
log_message(f"按店铺-产品训练: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
except Exception as e:
log_message(f"获取店铺产品数据失败: {e}", 'error')
return None
elif training_mode == 'global':
# 全局训练:聚合所有店铺的产品数据
try:
# 如果product_id是'unknown',则表示为全局所有商品训练一个聚合模型
if product_id == 'unknown':
product_data = aggregate_multi_store_data(
product_id=None, # 传递None以触发真正的全局聚合
aggregation_method=aggregation_method,
file_path=self.data_path
)
log_message(f"全局训练模式: 所有产品, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
# 将product_id设置为一个有意义的标识符
product_id = 'all_products'
else:
product_data = aggregate_multi_store_data(
product_id=product_id,
aggregation_method=aggregation_method,
file_path=self.data_path
)
log_message(f"全局训练模式: 产品 {product_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
except Exception as e:
log_message(f"聚合全局数据失败: {e}", 'error')
return None
else:
log_message(f"不支持的训练模式: {training_mode}", 'error')
return None
# 根据训练模式构建模型标识符
if training_mode == 'store':
model_identifier = f"{store_id}_{product_id}"
elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}"
else:
model_identifier = product_id
# 调用相应的训练函数
try:
log_message(f"🤖 开始调用 {model_type} 训练器")
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id=product_id,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
version=version,
socketio=socketio,
task_id=task_id,
continue_training=continue_training
)
log_message(f"{model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
elif model_type == 'mlstm':
_, metrics, _, _ = train_product_model_with_mlstm(
product_id=product_id,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id,
progress_callback=progress_callback
)
elif model_type == 'kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
use_optimized=use_optimized,
model_dir=self.model_dir
)
elif model_type == 'optimized_kan':
_, metrics = train_product_model_with_kan(
product_id=product_id,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
use_optimized=True,
model_dir=self.model_dir
)
elif model_type == 'tcn':
_, metrics, _, _ = train_product_model_with_tcn(
product_id=product_id,
product_df=product_data,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id
)
else:
log_message(f"不支持的模型类型: {model_type}", 'error')
return None
# 检查和打印返回的metrics
log_message(f"📊 训练完成检查返回的metrics: {metrics}")
# 在返回的metrics中添加训练信息
if metrics:
log_message(f"✅ metrics不为空添加训练信息")
metrics.update({
'training_mode': training_mode,
'store_id': store_id,
'product_id': product_id,
'model_identifier': model_identifier,
'aggregation_method': aggregation_method if training_mode == 'global' else None
})
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
else:
log_message(f"⚠️ metrics为空或None", 'warning')
return metrics
except Exception as e:
log_message(f"模型训练失败: {e}", 'error')
return None
def predict(self, product_id, model_type, future_days=7, start_date=None, analyze_result=False, version=None,
store_id=None, training_mode='product', aggregation_method='sum'):
"""
使用已训练的模型进行预测 - 支持多店铺预测
参数:
product_id: 产品ID
model_type: 模型类型
future_days: 预测未来天数
start_date: 预测起始日期
analyze_result: 是否分析预测结果
version: 模型版本如果为None则使用最新版本
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局预测
返回:
预测结果和分析如果analyze_result为True
"""
# 根据训练模式构建模型标识符
if training_mode == 'store' and store_id:
model_identifier = f"{store_id}_{product_id}"
elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}"
else:
model_identifier = product_id
return load_model_and_predict(
model_identifier,
model_type,
future_days=future_days,
start_date=start_date,
analyze_result=analyze_result,
version=version
)
def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
训练优化版KAN模型便捷方法
参数与train_model相同但固定model_type为'kan'且use_optimized为True
"""
return self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
def compare_kan_models(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
比较原始KAN和优化版KAN模型性能
参数与train_model相同
返回:
比较结果字典
"""
print(f"开始比较产品 {product_id} 的原始KAN和优化版KAN模型性能...")
# 训练原始KAN模型
print("\n训练原始KAN模型...")
kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=False
)
# 训练优化版KAN模型
print("\n训练优化版KAN模型...")
optimized_kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
# 比较结果
comparison = {
'kan': kan_metrics,
'optimized_kan': optimized_kan_metrics
}
# 打印比较结果
print("\n模型性能比较:")
print(f"{'指标':<10} {'原始KAN':<15} {'优化版KAN':<15} {'改进百分比':<15}")
print("-" * 55)
for metric in ['mse', 'rmse', 'mae', 'mape']:
if metric in kan_metrics and metric in optimized_kan_metrics:
kan_value = kan_metrics[metric]
opt_value = optimized_kan_metrics[metric]
improvement = (kan_value - opt_value) / kan_value * 100 if kan_value != 0 else 0
print(f"{metric.upper():<10} {kan_value:<15.4f} {opt_value:<15.4f} {improvement:<15.2f}%")
# R²值越高越好所以计算改进的方式不同
if 'r2' in kan_metrics and 'r2' in optimized_kan_metrics:
kan_r2 = kan_metrics['r2']
opt_r2 = optimized_kan_metrics['r2']
improvement = (opt_r2 - kan_r2) / (1 - kan_r2) * 100 if kan_r2 != 1 else 0
print(f"{'':<10} {kan_r2:<15.4f} {opt_r2:<15.4f} {improvement:<15.2f}%")
# 训练时间
if 'training_time' in kan_metrics and 'training_time' in optimized_kan_metrics:
kan_time = kan_metrics['training_time']
opt_time = optimized_kan_metrics['training_time']
time_diff = (opt_time - kan_time) / kan_time * 100 if kan_time != 0 else 0
print(f"{'时间(秒)':<10} {kan_time:<15.2f} {opt_time:<15.2f} {time_diff:<15.2f}%")
return comparison
def list_available_models(self, product_id=None, store_id=None, training_mode=None):
"""
列出可用的已训练模型 - 支持多店铺模型
参数:
product_id: 产品ID如果为None则列出所有模型
store_id: 店铺ID用于筛选店铺专属模型
training_mode: 训练模式筛选 ('product', 'store', 'global')
返回:
可用模型列表
"""
if not os.path.exists(self.model_dir):
print(f"模型目录 {self.model_dir} 不存在")
return []
model_files = os.listdir(self.model_dir)
models = []
for file in model_files:
if file.endswith('.pth'):
try:
# 解析模型文件名
model_info = self._parse_model_filename(file)
if model_info:
# 应用过滤条件
if product_id and model_info.get('product_id') != product_id:
continue
if store_id and model_info.get('store_id') != store_id:
continue
if training_mode and model_info.get('training_mode') != training_mode:
continue
model_info['file_name'] = file
model_info['file_path'] = os.path.join(self.model_dir, file)
models.append(model_info)
except Exception as e:
print(f"解析模型文件名失败: {file}, 错误: {e}")
continue
return models
def _parse_model_filename(self, filename):
"""
解析模型文件名,提取模型信息
参数:
filename: 模型文件名
返回:
dict: 模型信息字典
"""
# 移除文件扩展名
name = filename.replace('.pth', '')
# 解析新的多店铺模型命名格式
if '_model_product_' in name:
parts = name.split('_model_product_')
model_type = parts[0]
product_part = parts[1]
# 检查是否是店铺模型 (格式: model_type_model_product_store_id_product_id)
if len(product_part.split('_')) > 1:
store_id = product_part.split('_')[0]
product_id = '_'.join(product_part.split('_')[1:])
training_mode = 'store'
# 检查是否是全局模型 (格式: model_type_model_product_global_product_id_method)
elif product_part.startswith('global_'):
parts = product_part.split('_')
if len(parts) >= 3:
product_id = '_'.join(parts[1:-1])
aggregation_method = parts[-1]
store_id = None
training_mode = 'global'
else:
product_id = product_part
store_id = None
training_mode = 'product'
else:
# 常规产品模型
product_id = product_part
store_id = None
training_mode = 'product'
# 处理优化版KAN模型
if 'optimized' in model_type:
model_type = 'optimized_kan'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method if training_mode == 'global' and 'aggregation_method' in locals() else None
}
# 处理旧格式的向后兼容性
elif "kan_optimized_model" in name:
model_type = "optimized_kan"
product_id = name.split('_product_')[1] if '_product_' in name else 'unknown'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': None,
'training_mode': 'product',
'aggregation_method': None
}
return None
def delete_model(self, product_id, model_type):
"""
删除已训练的模型
参数:
product_id: 产品ID
model_type: 模型类型
返回:
是否成功删除
"""
model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
model_name = f"{model_type}{model_suffix}_model_product_{product_id}.pth"
model_path = os.path.join(self.model_dir, model_name)
if os.path.exists(model_path):
os.remove(model_path)
print(f"已删除模型: {model_path}")
return True
else:
print(f"模型文件 {model_path} 不存在")
return False