ShopTRAINING/run_pharmacy_prediction.py

279 lines
11 KiB
Python
Raw Normal View History

import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from datetime import datetime
import sys
# 导入模块
from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer
from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict
from models import ModelManager
from pharmacy_predictor import PharmacyPredictor
from generate_pharmacy_data import generate_pharmacy_data
def clear_screen():
"""清空控制台屏幕"""
os.system('cls' if os.name == 'nt' else 'clear')
def print_header():
"""打印应用程序标题"""
print("=" * 80)
print(" " * 25 + "药店销售预测系统" + " " * 25)
print("=" * 80)
def main_menu():
"""显示主菜单并获取用户选择"""
print("\n主菜单:")
print("1. 训练所有药品的销售预测模型")
print("2. 训练单个药品的销售预测模型 (Transformer)")
print("3. 训练单个药品的销售预测模型 (mLSTM)")
print("4. 训练单个药品的销售预测模型 (KAN)")
print("5. 训练单个药品的销售预测模型 (优化版KAN)") # 新增选项
print("6. 比较原始KAN和优化版KAN模型性能") # 新增选项
print("7. 查看已有预测结果")
print("8. 使用已训练的模型进行预测")
print("9. 比较不同模型的预测结果")
print("10. 模型管理")
print("0. 退出")
choice = input("\n请输入您的选择 (0-10): ")
return choice
def train_all_products_menu(predictor):
"""训练所有药品的菜单"""
clear_screen()
print_header()
print("\n选择模型类型:")
print("1. Transformer")
print("2. mLSTM")
print("3. KAN")
print("4. 优化版KAN") # 新增选项
print("0. 返回主菜单")
choice = input("\n请输入您的选择 (0-4): ")
if choice == '0':
return
model_types = {
'1': 'transformer',
'2': 'mlstm',
'3': 'kan',
'4': 'kan' # 使用相同的模型类型但在训练时指定use_optimized=True
}
if choice in model_types:
model_type = model_types[choice]
use_optimized = (choice == '4') # 如果选择4则使用优化版KAN
# 获取所有产品ID
product_ids = predictor.data['product_id'].unique()
for product_id in product_ids:
print(f"\n正在训练产品 {product_id} 的模型...")
if model_type == 'kan' and use_optimized:
predictor.train_optimized_kan_model(product_id)
else:
predictor.train_model(product_id, model_type=model_type)
print("\n所有产品的模型训练完成!")
input("\n按Enter键继续...")
else:
print("\n无效的选择!")
input("\n按Enter键继续...")
def train_single_product_menu(predictor, model_type, use_optimized=False):
"""训练单个药品的菜单"""
clear_screen()
print_header()
model_name = model_type
if model_type == 'kan' and use_optimized:
model_name = "优化版KAN"
print(f"\n训练单个药品的{model_name}模型")
# 获取所有产品ID
if predictor.data is not None:
product_ids = predictor.data['product_id'].unique()
print("\n可用的产品ID:")
for i, product_id in enumerate(product_ids):
print(f"{i+1}. {product_id}")
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
if choice == '0':
return
try:
idx = int(choice) - 1
if 0 <= idx < len(product_ids):
product_id = product_ids[idx]
# 获取训练参数
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
num_layers = int(input("请输入层数 (默认2): ") or "2")
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
print(f"\n开始训练产品 {product_id}{model_name}模型...")
if model_type == 'kan' and use_optimized:
# 使用优化版KAN模型
metrics = predictor.train_optimized_kan_model(
product_id=product_id,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout
)
else:
# 使用普通模型
metrics = predictor.train_model(
product_id=product_id,
model_type=model_type,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout
)
print("\n模型训练完成!")
print(f"MSE: {metrics['mse']:.4f}")
print(f"RMSE: {metrics['rmse']:.4f}")
print(f"MAE: {metrics['mae']:.4f}")
print(f"R²: {metrics['r2']:.4f}")
print(f"MAPE: {metrics['mape']:.2f}%")
print(f"训练时间: {metrics['training_time']:.2f}")
else:
print("\n无效的选择!")
except (ValueError, IndexError):
print("\n无效的输入!")
else:
print("\n没有可用的数据。请先生成或加载数据。")
input("\n按Enter键继续...")
def compare_kan_models_menu(predictor):
"""比较原始KAN和优化版KAN模型性能的菜单"""
clear_screen()
print_header()
print("\n比较原始KAN和优化版KAN模型性能")
# 获取所有产品ID
if predictor.data is not None:
product_ids = predictor.data['product_id'].unique()
print("\n可用的产品ID:")
for i, product_id in enumerate(product_ids):
print(f"{i+1}. {product_id}")
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
if choice == '0':
return
try:
idx = int(choice) - 1
if 0 <= idx < len(product_ids):
product_id = product_ids[idx]
# 获取训练参数
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
num_layers = int(input("请输入层数 (默认2): ") or "2")
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
print(f"\n开始比较产品 {product_id} 的原始KAN和优化版KAN模型...")
# 使用比较方法
comparison = predictor.compare_kan_models(
product_id=product_id,
epochs=epochs,
batch_size=batch_size,
learning_rate=learning_rate,
sequence_length=sequence_length,
forecast_horizon=forecast_horizon,
hidden_size=hidden_size,
num_layers=num_layers,
dropout=dropout
)
print("\n比较完成!")
else:
print("\n无效的选择!")
except (ValueError, IndexError):
print("\n无效的输入!")
else:
print("\n没有可用的数据。请先生成或加载数据。")
input("\n按Enter键继续...")
def main():
"""主函数"""
# 检查数据文件是否存在,如果不存在则生成
data_path = 'pharmacy_sales.xlsx'
if not os.path.exists(data_path):
print("生成药店销售数据...")
generate_pharmacy_data(num_products=10, days=365, output_file=data_path)
print(f"数据已生成并保存到 {data_path}")
# 创建预测器实例
predictor = PharmacyPredictor(data_path=data_path)
while True:
clear_screen()
print_header()
choice = main_menu()
if choice == '0':
print("\n感谢使用药店销售预测系统,再见!")
break
elif choice == '1':
train_all_products_menu(predictor)
elif choice == '2':
train_single_product_menu(predictor, model_type='transformer')
elif choice == '3':
train_single_product_menu(predictor, model_type='mlstm')
elif choice == '4':
train_single_product_menu(predictor, model_type='kan')
elif choice == '5':
# 使用优化版KAN模型
train_single_product_menu(predictor, model_type='kan', use_optimized=True)
elif choice == '6':
# 比较原始KAN和优化版KAN模型性能
compare_kan_models_menu(predictor)
elif choice == '7':
# 查看已有预测结果
pass
elif choice == '8':
# 使用已训练的模型进行预测
pass
elif choice == '9':
# 比较不同模型的预测结果
pass
elif choice == '10':
# 模型管理
pass
else:
print("\n无效的选择,请重试!")
input("\n按Enter键继续...")
if __name__ == "__main__":
main()