ShopTRAINING/server/core/predictor.py
2025-07-02 11:05:23 +08:00

514 lines
20 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类
支持多店铺销售预测功能
"""
import os
import pandas as pd
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from utils.multi_store_data_utils import (
load_multi_store_data,
get_store_product_sales_data,
aggregate_multi_store_data
)
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
class PharmacyPredictor:
"""
药店销售预测系统核心类,用于训练模型和进行预测
"""
def __init__(self, data_path=None, model_dir=DEFAULT_MODEL_DIR):
"""
初始化预测器
参数:
data_path: 数据文件路径默认使用多店铺CSV文件
model_dir: 模型保存目录
"""
# 设置默认数据路径为多店铺CSV文件
if data_path is None:
data_path = 'pharmacy_sales_multi_store.csv'
self.data_path = data_path
self.model_dir = model_dir
self.device = DEVICE
if not os.path.exists(model_dir):
os.makedirs(model_dir)
print(f"使用设备: {self.device}")
# 尝试加载多店铺数据
try:
self.data = load_multi_store_data(data_path)
print(f"已加载多店铺数据,来源: {data_path}")
except Exception as e:
print(f"加载数据失败: {e}")
self.data = None
def train_model(self, product_id, model_type='transformer', epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False,
store_id=None, training_mode='product', aggregation_method='sum',
socketio=None, task_id=None, version=None, continue_training=False,
progress_callback=None):
"""
训练预测模型 - 支持多店铺训练
参数:
product_id: 产品ID
model_type: 模型类型 ('transformer', 'mlstm', 'kan', 'tcn', 'optimized_kan')
epochs: 训练轮次
batch_size: 批次大小
learning_rate: 学习率
sequence_length: 输入序列长度
forecast_horizon: 预测天数
hidden_size: 隐藏层大小
num_layers: 层数
dropout: Dropout比例
use_optimized: 是否使用优化版KAN仅当model_type为'kan'时有效)
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局训练
返回:
metrics: 模型评估指标
"""
# 创建统一的输出函数
def log_message(message, log_type='info'):
"""统一的日志输出函数"""
print(message, flush=True) # 始终输出到控制台
# 如果有进度回调,也发送到回调
if progress_callback:
try:
progress_callback({
'log_type': log_type,
'message': message
})
except Exception as e:
print(f"进度回调失败: {e}", flush=True)
if self.data is None:
log_message("没有可用的数据,请先加载或生成数据", 'error')
return None
# 根据训练模式准备数据
if training_mode == 'product':
# 按产品训练:使用所有店铺的该产品数据
product_data = self.data[self.data['product_id'] == product_id].copy()
if product_data.empty:
log_message(f"找不到产品 {product_id} 的数据", 'error')
return None
log_message(f"按产品训练模式: 产品 {product_id}, 数据量: {len(product_data)}")
elif training_mode == 'store':
# 按店铺训练:使用特定店铺的特定产品数据
if not store_id:
log_message("店铺训练模式需要指定 store_id", 'error')
return None
try:
product_data = get_store_product_sales_data(
store_id=store_id,
product_id=product_id,
file_path=self.data_path
)
log_message(f"按店铺训练模式: 店铺 {store_id}, 产品 {product_id}, 数据量: {len(product_data)}")
except Exception as e:
log_message(f"获取店铺产品数据失败: {e}", 'error')
return None
elif training_mode == 'global':
# 全局训练:聚合所有店铺的产品数据
try:
product_data = aggregate_multi_store_data(
product_id=product_id,
aggregation_method=aggregation_method,
file_path=self.data_path
)
log_message(f"全局训练模式: 产品 {product_id}, 聚合方法 {aggregation_method}, 数据量: {len(product_data)}")
except Exception as e:
log_message(f"聚合全局数据失败: {e}", 'error')
return None
else:
log_message(f"不支持的训练模式: {training_mode}", 'error')
return None
# 根据训练模式构建模型标识符
if training_mode == 'store':
model_identifier = f"{store_id}_{product_id}"
elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}"
else:
model_identifier = product_id
# 调用相应的训练函数
try:
log_message(f"🤖 开始调用 {model_type} 训练器")
if model_type == 'transformer':
model_result, metrics, actual_version = train_product_model_with_transformer(
product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
version=version,
socketio=socketio,
task_id=task_id,
continue_training=continue_training
)
log_message(f"{model_type} 训练器返回: metrics={type(metrics)}, version={actual_version}", 'success')
elif model_type == 'mlstm':
_, metrics, _, _ = train_product_model_with_mlstm(
product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id,
progress_callback=progress_callback
)
elif model_type == 'kan':
_, metrics = train_product_model_with_kan(
product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
use_optimized=use_optimized,
model_dir=self.model_dir
)
elif model_type == 'optimized_kan':
_, metrics = train_product_model_with_kan(
product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
use_optimized=True,
model_dir=self.model_dir
)
elif model_type == 'tcn':
_, metrics, _, _ = train_product_model_with_tcn(
product_id,
store_id=store_id,
training_mode=training_mode,
aggregation_method=aggregation_method,
epochs=epochs,
model_dir=self.model_dir,
socketio=socketio,
task_id=task_id
)
else:
log_message(f"不支持的模型类型: {model_type}", 'error')
return None
# 检查和打印返回的metrics
log_message(f"📊 训练完成检查返回的metrics: {metrics}")
# 在返回的metrics中添加训练信息
if metrics:
log_message(f"✅ metrics不为空添加训练信息")
metrics.update({
'training_mode': training_mode,
'store_id': store_id,
'product_id': product_id,
'model_identifier': model_identifier,
'aggregation_method': aggregation_method if training_mode == 'global' else None
})
log_message(f"📈 最终返回的metrics: {metrics}", 'success')
else:
log_message(f"⚠️ metrics为空或None", 'warning')
return metrics
except Exception as e:
log_message(f"模型训练失败: {e}", 'error')
return None
def predict(self, product_id, model_type, future_days=7, start_date=None, analyze_result=False, version=None,
store_id=None, training_mode='product', aggregation_method='sum'):
"""
使用已训练的模型进行预测 - 支持多店铺预测
参数:
product_id: 产品ID
model_type: 模型类型
future_days: 预测未来天数
start_date: 预测起始日期
analyze_result: 是否分析预测结果
version: 模型版本如果为None则使用最新版本
store_id: 店铺ID仅当training_mode为'store'时使用)
training_mode: 训练模式 ('product', 'store', 'global')
aggregation_method: 聚合方法 ('sum', 'mean', 'median') - 仅用于全局预测
返回:
预测结果和分析如果analyze_result为True
"""
# 根据训练模式构建模型标识符
if training_mode == 'store' and store_id:
model_identifier = f"{store_id}_{product_id}"
elif training_mode == 'global':
model_identifier = f"global_{product_id}_{aggregation_method}"
else:
model_identifier = product_id
return load_model_and_predict(
model_identifier,
model_type,
future_days=future_days,
start_date=start_date,
analyze_result=analyze_result,
version=version
)
def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
训练优化版KAN模型便捷方法
参数与train_model相同但固定model_type为'kan'且use_optimized为True
"""
return self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
def compare_kan_models(self, product_id, epochs=100, batch_size=32,
learning_rate=0.001, sequence_length=30, forecast_horizon=7,
hidden_size=64, num_layers=2, dropout=0.1):
"""
比较原始KAN和优化版KAN模型性能
参数与train_model相同
返回:
比较结果字典
"""
print(f"开始比较产品 {product_id} 的原始KAN和优化版KAN模型性能...")
# 训练原始KAN模型
print("\n训练原始KAN模型...")
kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=False
)
# 训练优化版KAN模型
print("\n训练优化版KAN模型...")
optimized_kan_metrics = self.train_model(
product_id=product_id,
model_type='kan',
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout,
use_optimized=True
)
# 比较结果
comparison = {
'kan': kan_metrics,
'optimized_kan': optimized_kan_metrics
}
# 打印比较结果
print("\n模型性能比较:")
print(f"{'指标':<10} {'原始KAN':<15} {'优化版KAN':<15} {'改进百分比':<15}")
print("-" * 55)
for metric in ['mse', 'rmse', 'mae', 'mape']:
if metric in kan_metrics and metric in optimized_kan_metrics:
kan_value = kan_metrics[metric]
opt_value = optimized_kan_metrics[metric]
improvement = (kan_value - opt_value) / kan_value * 100 if kan_value != 0 else 0
print(f"{metric.upper():<10} {kan_value:<15.4f} {opt_value:<15.4f} {improvement:<15.2f}%")
# R²值越高越好所以计算改进的方式不同
if 'r2' in kan_metrics and 'r2' in optimized_kan_metrics:
kan_r2 = kan_metrics['r2']
opt_r2 = optimized_kan_metrics['r2']
improvement = (opt_r2 - kan_r2) / (1 - kan_r2) * 100 if kan_r2 != 1 else 0
print(f"{'':<10} {kan_r2:<15.4f} {opt_r2:<15.4f} {improvement:<15.2f}%")
# 训练时间
if 'training_time' in kan_metrics and 'training_time' in optimized_kan_metrics:
kan_time = kan_metrics['training_time']
opt_time = optimized_kan_metrics['training_time']
time_diff = (opt_time - kan_time) / kan_time * 100 if kan_time != 0 else 0
print(f"{'时间(秒)':<10} {kan_time:<15.2f} {opt_time:<15.2f} {time_diff:<15.2f}%")
return comparison
def list_available_models(self, product_id=None, store_id=None, training_mode=None):
"""
列出可用的已训练模型 - 支持多店铺模型
参数:
product_id: 产品ID如果为None则列出所有模型
store_id: 店铺ID用于筛选店铺专属模型
training_mode: 训练模式筛选 ('product', 'store', 'global')
返回:
可用模型列表
"""
if not os.path.exists(self.model_dir):
print(f"模型目录 {self.model_dir} 不存在")
return []
model_files = os.listdir(self.model_dir)
models = []
for file in model_files:
if file.endswith('.pth'):
try:
# 解析模型文件名
model_info = self._parse_model_filename(file)
if model_info:
# 应用过滤条件
if product_id and model_info.get('product_id') != product_id:
continue
if store_id and model_info.get('store_id') != store_id:
continue
if training_mode and model_info.get('training_mode') != training_mode:
continue
model_info['file_name'] = file
model_info['file_path'] = os.path.join(self.model_dir, file)
models.append(model_info)
except Exception as e:
print(f"解析模型文件名失败: {file}, 错误: {e}")
continue
return models
def _parse_model_filename(self, filename):
"""
解析模型文件名,提取模型信息
参数:
filename: 模型文件名
返回:
dict: 模型信息字典
"""
# 移除文件扩展名
name = filename.replace('.pth', '')
# 解析新的多店铺模型命名格式
if '_model_product_' in name:
parts = name.split('_model_product_')
model_type = parts[0]
product_part = parts[1]
# 检查是否是店铺模型 (格式: model_type_model_product_store_id_product_id)
if len(product_part.split('_')) > 1:
store_id = product_part.split('_')[0]
product_id = '_'.join(product_part.split('_')[1:])
training_mode = 'store'
# 检查是否是全局模型 (格式: model_type_model_product_global_product_id_method)
elif product_part.startswith('global_'):
parts = product_part.split('_')
if len(parts) >= 3:
product_id = '_'.join(parts[1:-1])
aggregation_method = parts[-1]
store_id = None
training_mode = 'global'
else:
product_id = product_part
store_id = None
training_mode = 'product'
else:
# 常规产品模型
product_id = product_part
store_id = None
training_mode = 'product'
# 处理优化版KAN模型
if 'optimized' in model_type:
model_type = 'optimized_kan'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': store_id,
'training_mode': training_mode,
'aggregation_method': aggregation_method if training_mode == 'global' and 'aggregation_method' in locals() else None
}
# 处理旧格式的向后兼容性
elif "kan_optimized_model" in name:
model_type = "optimized_kan"
product_id = name.split('_product_')[1] if '_product_' in name else 'unknown'
return {
'model_type': model_type,
'product_id': product_id,
'store_id': None,
'training_mode': 'product',
'aggregation_method': None
}
return None
def delete_model(self, product_id, model_type):
"""
删除已训练的模型
参数:
product_id: 产品ID
model_type: 模型类型
返回:
是否成功删除
"""
model_suffix = '_optimized' if model_type == 'optimized_kan' else ''
model_name = f"{model_type}{model_suffix}_model_product_{product_id}.pth"
model_path = os.path.join(self.model_dir, model_name)
if os.path.exists(model_path):
os.remove(model_path)
print(f"已删除模型: {model_path}")
return True
else:
print(f"模型文件 {model_path} 不存在")
return False