ShopTRAINING/server/core/predictor.py

484 lines
19 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类
支持多店铺销售预测功能
"""
import os
import pandas as pd
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from datetime import datetime
# from trainers import (
# train_product_model_with_mlstm,
# train_product_model_with_kan,
# train_product_model_with_tcn,
# train_product_model_with_transformer
# )
# 上述导入已不再需要,因为我们现在通过模型注册表动态获取训练器
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import load_multi_store_data, aggregate_multi_store_data
# 导入新的特征选择模块
from utils.feature_selection import get_feature_list_for_model
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
class PharmacyPredictor:
"""
药店销售预测系统核心类,用于训练模型和进行预测
"""
def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
初始化预测器
参数:
data_path: 数据文件路径默认使用多店铺CSV文件
model_dir: 模型保存目录
"""
# 设置默认数据路径为多店铺CSV文件
if data_path is None:
data_path = DEFAULT_DATA_PATH
self.data_path = data_path
self.model_dir = model_dir
self.device = DEVICE
if not os.path.exists(model_dir):
os.makedirs(model_dir)
print(f"使用设备: {self.device}")
# 尝试加载多店铺数据
try:
self.data = load_multi_store_data(data_path)
print(f"已加载多店铺数据,来源: {data_path}")
except Exception as e:
print(f"加载数据失败: {e}")
self.data = None
def train_model(self, product_id, model_type='transformer', epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
store_id=None, training_mode='product', aggregation_method='sum',
socketio=None, task_id=None, version=None, continue_training=False,
progress_callback=None):
"""
训练预测模型 - 支持多店铺训练
参数:
product_id: 产品ID
model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
epochs: 训练轮次
batch_size: 批次大小
learning_rate: 学习率
sequence_length: 输入序列长度
forecast_horizon: 预测天数
hidden_size: 隐藏层大小
num_layers: 层数
dropout: Dropout比例
use_optimized: 是否使用优化版KAN仅当model_type为'kan'时有效)
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局训练
返回:
metrics: 模型评估指标
"""
# 创建统一的输出函数
def log_message(message, log_type='info'):
"""统一的日志输出函数"""
print(message, flush=True) # 始终输出到控制台
# 如果有进度回调,也发送到回调
if progress_callback:
try:
progress_callback({
'log_type': log_type,
'message': message
})
except Exception as e:
print(f"进度回调失败: {e}", flush=True)
if self.data is None:
log_message("没有可用的数据,请先加载或生成数据", 'error')
return None
# --- 新数据管道 ---
# 1. 加载完整的、经过基础标准化的数据
# 注意此时的load_multi_store_data已经过改造不再创造特征
full_data = load_multi_store_data(self.data_path)
if full_data.empty:
log_message("错误加载数据后得到空的DataFrame。", 'error')
return None, None
# 2. 根据训练模式,筛选出本次训练所需的数据子集
if training_mode == 'product':
training_df = full_data[full_data['product_id'] == product_id].copy()
model_identifier = product_id
elif training_mode == 'store':
training_df = full_data[full_data['store_id'] == store_id].copy()
model_identifier = f"store_{store_id}"
elif training_mode == 'global':
# 全局模型使用所有数据
training_df = full_data.copy()
model_identifier = f"global_{aggregation_method}"
else:
log_message(f"不支持的训练模式: {training_mode}", 'error')
return None, None
if training_df.empty:
log_message(f"错误:根据训练模式 '{training_mode}' 和标识 '{product_id or store_id}' 筛选后,没有剩余数据。", 'error')
return None, None
log_message(f"数据筛选完成,用于训练的记录数: {len(training_df)}")
# 3. 根据模型类型,获取专属的特征列表
all_columns = training_df.columns.tolist()
feature_list = get_feature_list_for_model(model_type, all_columns)
if not feature_list:
log_message(f"错误:未能为模型 '{model_type}' 获取任何特征。", 'error')
return None, None
log_message(f"为模型 '{model_type}' 选择了 {len(feature_list)} 个特征: {feature_list[:5]}...")
# 4. 调用相应的训练器,并传入数据和特征列表
try:
from models.model_registry import get_trainer
trainer_function = get_trainer(model_type)
# --- 临时添加:用于快速测试的调试模式 ---
debug_fast_mode = True
if debug_fast_mode:
print("🚀 快速测试模式已激活截取前100条数据进行训练。")
training_df = training_df.head(66)
# ------------------------------------
# 准备通用参数
trainer_args = {
'product_id': product_id,
'model_identifier': model_identifier,
'training_df': training_df, # 传入筛选后的数据
'feature_list': feature_list, # 传入选择好的特征
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method,
'epochs': epochs,
'sequence_length': sequence_length,
'forecast_horizon': forecast_horizon,
'model_dir': self.model_dir,
'socketio': socketio,
'task_id': task_id,
'progress_callback': progress_callback,
'version': version,
'continue_training': continue_training,
'use_optimized': use_optimized
}
# 动态过滤不兼容的参数
import inspect
sig = inspect.signature(trainer_function)
valid_args = {k: v for k, v in trainer_args.items() if k in sig.parameters}
log_message(f"准备调用 {trainer_function.__name__}...")
result = trainer_function(**valid_args)
# 解析返回结果
if isinstance(result, tuple) and len(result) >= 2:
metrics, artifacts = result[0], result[1]
else:
log_message(f"训练器返回格式未知: {type(result)}", 'warning')
return None, None
# 在返回的metrics中添加训练信息
if metrics:
metrics.update({
'training_mode': training_mode,
'store_id': store_id,
'product_id': product_id,
'model_identifier': model_identifier,
'aggregation_method': aggregation_method if training_mode == 'global' else None
})
return metrics, artifacts
except Exception as e:
import traceback
log_message(f"模型训练失败: {e}\n{traceback.format_exc()}", 'error')
return None, None
def predict(self, product_id, model_type, future_days=7, start_date=None, analyze_result=False, version=None,
store_id=None, training_mode='product', aggregation_method='sum'):
"""
使用已训练的模型进行预测 - 支持多店铺预测
参数:
product_id: 产品ID
model_type: 模型类型
future_days: 预测未来天数
start_date: 预测起始日期
analyze_result: 是否分析预测结果
version: 模型版本如果为None则使用最新版本
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局预测
返回:
预测结果和分析如果analyze_result为True
"""
# 根据训练模式构建模型标识符 (v2 修正)
if training_mode == 'store' and store_id:
model_identifier = f"store_{store_id}"
elif training_mode == 'global':
# 全局模型的标识符不应依赖于单个product_id
model_identifier = f"global_{aggregation_method}"
else: # product mode
model_identifier = product_id
return load_model_and_predict(
model_identifier,
model_type,
store_id=store_id,
future_days=future_days,
start_date=start_date,
analyze_result=analyze_result,
version=version,
training_mode=training_mode
)
def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
训练优化版KAN模型便捷方法
参数与train_model相同但固定model_type为'kan'且use_optimized为True
"""
return self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
def compare_kan_models(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
比较原始KAN和优化版KAN模型性能
参数与train_model相同
返回:
比较结果字典
"""
print(f"开始比较产品 {product_id} 的原始KAN和优化版KAN模型性能...")
# 训练原始KAN模型
print("\n训练原始KAN模型...")
kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=False
)
# 训练优化版KAN模型
print("\n训练优化版KAN模型...")
optimized_kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
# 比较结果
comparison = {
'kan': kan_metrics,
'optimized_kan': optimized_kan_metrics
}
# 打印比较结果
print("\n模型性能比较:")
print(f"{'指标':<10} {'原始KAN':<15} {'优化版KAN':<15} {'改进百分比':<15}")
print("-" * 55)
for metric in ['mse', 'rmse', 'mae', 'mape']:
if metric in kan_metrics and metric in optimized_kan_metrics:
kan_value = kan_metrics[metric]
opt_value = optimized_kan_metrics[metric]
improvement = (kan_value - opt_value) / kan_value * 100 if kan_value != 0 else 0
print(f"{metric.upper():<10} {kan_value:<15.4f} {opt_value:<15.4f} {improvement:<15.2f}%")
# R²值越高越好所以计算改进的方式不同
if 'r2' in kan_metrics and 'r2' in optimized_kan_metrics:
kan_r2 = kan_metrics['r2']
opt_r2 = optimized_kan_metrics['r2']
improvement = (opt_r2 - kan_r2) / (1 - kan_r2) * 100 if kan_r2 != 1 else 0
print(f"{'':<10} {kan_r2:<15.4f} {opt_r2:<15.4f} {improvement:<15.2f}%")
# 训练时间
if 'training_time' in kan_metrics and 'training_time' in optimized_kan_metrics:
kan_time = kan_metrics['training_time']
opt_time = optimized_kan_metrics['training_time']
time_diff = (opt_time - kan_time) / kan_time * 100 if kan_time != 0 else 0
print(f"{'时间(秒)':<10} {kan_time:<15.2f} {opt_time:<15.2f} {time_diff:<15.2f}%")
return comparison
def list_available_models(self, product_id=None, store_id=None, training_mode=None):
"""
列出可用的已训练模型 - 支持多店铺模型
参数:
product_id: 产品ID如果为None则列出所有模型
store_id: 店铺ID用于筛选店铺专属模型
training_mode: 训练模式筛选 ('product', 'store', 'global')
返回:
可用模型列表
"""
if not os.path.exists(self.model_dir):
print(f"模型目录 {self.model_dir} 不存在")
return []
model_files = os.listdir(self.model_dir)
models = []
for file in model_files:
if file.endswith('.pth'):
try:
# 解析模型文件名
model_info = self._parse_model_filename(file)
if model_info:
# 应用过滤条件
if product_id and model_info.get('product_id') != product_id:
continue
if store_id and model_info.get('store_id') != store_id:
continue
if training_mode and model_info.get('training_mode') != training_mode:
continue
model_info['file_name'] = file
model_info['file_path'] = os.path.join(self.model_dir, file)
models.append(model_info)
except Exception as e:
print(f"解析模型文件名失败: {file}, 错误: {e}")
continue
return models
def _parse_model_filename(self, filename):
"""
解析模型文件名,提取模型信息
参数:
filename: 模型文件名
返回:
dict: 模型信息字典
"""
# 移除文件扩展名
name = filename.replace('.pth', '')
# 解析新的多店铺模型命名格式
if '_model_product_' in name:
parts = name.split('_model_product_')
model_type = parts[0]
product_part = parts[1]
# 检查是否是店铺模型 (格式: model_type_model_product_store_id_product_id)
if len(product_part.split('_')) > 1:
store_id = product_part.split('_')[0]
product_id = '_'.join(product_part.split('_')[1:])
training_mode = 'store'
# 检查是否是全局模型 (格式: model_type_model_product_global_product_id_method)
elif product_part.startswith('global_'):
parts = product_part.split('_')
if len(parts) >= 3:
product_id = '_'.join(parts[1:-1])
aggregation_method = parts[-1]
store_id = None
training_mode = 'global'
else:
product_id = product_part
store_id = None
training_mode = 'product'
else:
# 常规产品模型
product_id = product_part
store_id = None
training_mode = 'product'
# 处理优化版KAN模型
if 'optimized' in model_type:
model_type = 'optimized_kan'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method if training_mode == 'global' and 'aggregation_method' in locals() else None
}
# 处理旧格式的向后兼容性
elif "kan_optimized_model" in name:
model_type = "optimized_kan"
product_id = name.split('_product_')[1] if '_product_' in name else 'unknown'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': None,
'training_mode': 'product',
'aggregation_method': None
}
return None
def delete_model(self, product_id, model_type):
"""
删除已训练的模型
参数:
product_id: 产品ID
model_type: 模型类型
返回:
是否成功删除
"""
model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
model_name = f"{model_type}{model_suffix}_model_product_{product_id}.pth"
model_path = os.path.join(self.model_dir, model_name)
if os.path.exists(model_path):
os.remove(model_path)
print(f"已删除模型: {model_path}")
return True
else:
print(f"模型文件 {model_path} 不存在")
return False