# ShopTRAINING/server/core/predictor.py
#
# 285 lines
# 10 KiB
# Python
# Raw Blame History
#
# This file contains ambiguous Unicode characters
#
# This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
药店销售预测系统 - 核心预测器类
"""
import os
import pandas as pd
import numpy as np
import torch
import time
import matplotlib.pyplot as plt
from datetime import datetime
from trainers import (
train_product_model_with_mlstm,
train_product_model_with_kan,
train_product_model_with_tcn,
train_product_model_with_transformer
)
from predictors.model_predictor import load_model_and_predict
from utils.data_utils import prepare_data, prepare_sequences
from analysis.metrics import evaluate_model
from core.config import DEVICE, DEFAULT_MODEL_DIR, DEFAULT_DATA_PATH
class PharmacyPredictor:
    """Core class of the pharmacy sales forecasting system.

    Wraps the per-model training functions, the prediction entry point and
    simple management (listing / deleting) of saved ``.pth`` model files.
    """

    def __init__(self, data_path=DEFAULT_DATA_PATH, model_dir=DEFAULT_MODEL_DIR):
        """Initialise the predictor and load the sales data if present.

        Args:
            data_path: Path of the Excel file holding the sales data.
            model_dir: Directory in which trained models are stored; it is
                created when missing.

        ``self.data`` is set to the loaded DataFrame, or to ``None`` when
        ``data_path`` does not exist.
        """
        self.data_path = data_path
        self.model_dir = model_dir
        self.device = DEVICE

        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(model_dir, exist_ok=True)

        print(f"使用设备: {self.device}")

        if os.path.exists(data_path):
            self.data = pd.read_excel(data_path)
            print(f"已加载数据,来源: {data_path}")
        else:
            print(f"数据文件 {data_path} 不存在,请先生成数据")
            self.data = None

    def train_model(self, product_id, model_type='transformer', epochs=100, batch_size=32,
                    learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                    hidden_size=64, num_layers=2, dropout=0.1, use_optimized=False):
        """Train a forecasting model for one product.

        Args:
            product_id: Product identifier, matched against the
                ``product_id`` column of the loaded data.
            model_type: One of ``'transformer'``, ``'mlstm'``, ``'kan'``,
                ``'tcn'`` or ``'optimized_kan'``.
            epochs: Number of training epochs, forwarded to the trainer.
            batch_size, learning_rate, sequence_length, forecast_horizon,
            hidden_size, num_layers, dropout: Accepted for interface
                compatibility.  NOTE(review): none of these are forwarded to
                the trainer functions by this method, so they currently
                have no effect here.
            use_optimized: Use the optimised KAN variant; only consulted
                when ``model_type == 'kan'``.

        Returns:
            The metrics object returned by the trainer, or ``None`` when no
            data is loaded, the product is unknown or ``model_type`` is not
            supported.
        """
        if self.data is None:
            print("没有可用的数据,请先加载或生成数据")
            return None

        # Only existence matters here; the trainers are called with the
        # product id, not with the filtered rows.
        if not (self.data['product_id'] == product_id).any():
            print(f"找不到产品 {product_id} 的数据")
            return None

        if model_type == 'transformer':
            _, metrics = train_product_model_with_transformer(
                product_id, epochs=epochs, model_dir=self.model_dir)
        elif model_type == 'mlstm':
            _, metrics = train_product_model_with_mlstm(
                product_id, epochs=epochs, model_dir=self.model_dir)
        elif model_type in ('kan', 'optimized_kan'):
            # 'optimized_kan' is the KAN trainer with the optimised flag forced on.
            optimized = True if model_type == 'optimized_kan' else use_optimized
            _, metrics = train_product_model_with_kan(
                product_id, epochs=epochs, use_optimized=optimized,
                model_dir=self.model_dir)
        elif model_type == 'tcn':
            _, metrics = train_product_model_with_tcn(
                product_id, epochs=epochs, model_dir=self.model_dir)
        else:
            print(f"不支持的模型类型: {model_type}")
            return None

        return metrics

    def predict(self, product_id, model_type, future_days=7, start_date=None, analyze_result=False):
        """Predict with a previously trained model.

        Thin wrapper that forwards every argument to
        ``load_model_and_predict``.

        Args:
            product_id: Product identifier.
            model_type: Model type to load.
            future_days: Number of future days to predict.
            start_date: Start date of the prediction.
            analyze_result: Whether the prediction result should be analysed.

        Returns:
            Whatever ``load_model_and_predict`` returns.
        """
        return load_model_and_predict(
            product_id,
            model_type,
            future_days=future_days,
            start_date=start_date,
            analyze_result=analyze_result
        )

    def train_optimized_kan_model(self, product_id, epochs=100, batch_size=32,
                                  learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                                  hidden_size=64, num_layers=2, dropout=0.1):
        """Convenience wrapper that trains the optimised KAN model.

        Takes the same arguments as ``train_model`` but fixes
        ``model_type='kan'`` and ``use_optimized=True``.

        Returns:
            The result of ``train_model``.
        """
        return self.train_model(
            product_id=product_id,
            model_type='kan',
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            sequence_length=sequence_length,
            forecast_horizon=forecast_horizon,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
            use_optimized=True
        )

    def compare_kan_models(self, product_id, epochs=100, batch_size=32,
                           learning_rate=0.001, sequence_length=30, forecast_horizon=7,
                           hidden_size=64, num_layers=2, dropout=0.1):
        """Train the original and the optimised KAN model and compare them.

        Takes the same arguments as ``train_model``.  Both variants are
        trained in turn and a comparison table is printed.

        Returns:
            ``{'kan': <metrics>, 'optimized_kan': <metrics>}``.  A value is
            ``None`` when the corresponding training run produced no
            metrics; in that case no table is printed.
        """
        print(f"开始比较产品 {product_id} 的原始KAN和优化版KAN模型性能...")

        shared_kwargs = dict(
            product_id=product_id,
            model_type='kan',
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            sequence_length=sequence_length,
            forecast_horizon=forecast_horizon,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )

        # Train the original KAN model.
        print("\n训练原始KAN模型...")
        kan_metrics = self.train_model(use_optimized=False, **shared_kwargs)

        # Train the optimised KAN model.
        print("\n训练优化版KAN模型...")
        optimized_kan_metrics = self.train_model(use_optimized=True, **shared_kwargs)

        comparison = {
            'kan': kan_metrics,
            'optimized_kan': optimized_kan_metrics
        }

        # train_model returns None on missing data / unknown product; the
        # membership tests below would raise TypeError on None.
        if kan_metrics is None or optimized_kan_metrics is None:
            print("模型训练失败,无法比较性能")
            return comparison

        print("\n模型性能比较:")
        print(f"{'指标':<10} {'原始KAN':<15} {'优化版KAN':<15} {'改进百分比':<15}")
        print("-" * 55)

        # Error metrics: lower is better, so improvement is the relative drop.
        for metric in ['mse', 'rmse', 'mae', 'mape']:
            if metric in kan_metrics and metric in optimized_kan_metrics:
                kan_value = kan_metrics[metric]
                opt_value = optimized_kan_metrics[metric]
                improvement = (kan_value - opt_value) / kan_value * 100 if kan_value != 0 else 0
                print(f"{metric.upper():<10} {kan_value:<15.4f} {opt_value:<15.4f} {improvement:<15.2f}%")

        # R²: higher is better, so improvement is measured as the share of
        # the remaining gap to 1 that was closed.
        if 'r2' in kan_metrics and 'r2' in optimized_kan_metrics:
            kan_r2 = kan_metrics['r2']
            opt_r2 = optimized_kan_metrics['r2']
            improvement = (opt_r2 - kan_r2) / (1 - kan_r2) * 100 if kan_r2 != 1 else 0
            # The row label was empty in the original, leaving the row unidentified.
            print(f"{'R²':<10} {kan_r2:<15.4f} {opt_r2:<15.4f} {improvement:<15.2f}%")

        # Training time: reported as relative change of the optimised run.
        if 'training_time' in kan_metrics and 'training_time' in optimized_kan_metrics:
            kan_time = kan_metrics['training_time']
            opt_time = optimized_kan_metrics['training_time']
            time_diff = (opt_time - kan_time) / kan_time * 100 if kan_time != 0 else 0
            print(f"{'时间(秒)':<10} {kan_time:<15.2f} {opt_time:<15.2f} {time_diff:<15.2f}%")

        return comparison

    def list_available_models(self, product_id=None):
        """List the trained model files found in ``self.model_dir``.

        Args:
            product_id: When given (not ``None`` and not an empty string),
                only models whose file name carries exactly this product id
                are returned; otherwise all models are listed.

        Returns:
            A list of dicts with the keys ``model_type``, ``product_id``,
            ``file_name`` and ``file_path``.  Empty when the model
            directory does not exist.
        """
        if not os.path.exists(self.model_dir):
            print(f"模型目录 {self.model_dir} 不存在")
            return []

        # `is None` rather than truthiness so that an id of 0 still filters.
        no_filter = product_id is None or product_id == ''
        wanted_id = None if no_filter else str(product_id)

        models = []
        for file_name in os.listdir(self.model_dir):
            # Skip anything that does not follow "<type>..._product_<id>.pth";
            # splitting such a name would raise IndexError.
            if not file_name.endswith('.pth') or '_product_' not in file_name:
                continue

            file_product_id = file_name.split('_product_')[1].split('.pth')[0]

            # Exact comparison: a substring test would let "P001" match "P0010".
            if wanted_id is not None and file_product_id != wanted_id:
                continue

            # Both "kan_optimized_model..." and any other "*_optimized_model..."
            # file are reported as the optimised KAN variant.
            if "_optimized_model" in file_name:
                model_type = "optimized_kan"
            else:
                model_type = file_name.split('_model_product_')[0]

            models.append({
                'model_type': model_type,
                'product_id': file_product_id,
                'file_name': file_name,
                'file_path': os.path.join(self.model_dir, file_name)
            })

        return models

    def delete_model(self, product_id, model_type):
        """Delete a trained model file.

        Args:
            product_id: Product identifier.
            model_type: Model type; ``'optimized_kan'`` refers to the
                optimised KAN model file.

        Returns:
            ``True`` when a file was removed, ``False`` when no matching
            file exists.
        """
        if model_type == 'optimized_kan':
            candidate_names = [
                # Name that list_available_models recognises as optimised KAN.
                f"kan_optimized_model_product_{product_id}.pth",
                # Name the previous implementation looked for; kept as fallback.
                f"optimized_kan_optimized_model_product_{product_id}.pth",
            ]
        else:
            candidate_names = [f"{model_type}_model_product_{product_id}.pth"]

        candidate_paths = [os.path.join(self.model_dir, name) for name in candidate_names]

        for model_path in candidate_paths:
            if os.path.exists(model_path):
                os.remove(model_path)
                print(f"已删除模型: {model_path}")
                return True

        print(f"模型文件 {candidate_paths[0]} 不存在")
        return False