649 lines
26 KiB
Python
649 lines
26 KiB
Python
import os
|
|
import sys
|
|
|
|
# 获取当前脚本所在目录的绝对路径
|
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
|
|
# 将当前目录添加到系统路径
|
|
sys.path.append(current_dir)
|
|
|
|
import pandas as pd
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
from datetime import datetime
|
|
|
|
# 直接从各个模块导入所需的函数和类
|
|
from core.predictor import PharmacyPredictor
|
|
from trainers import train_product_model
|
|
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
|
from trainers.kan_trainer import train_product_model_with_kan
|
|
from trainers.tcn_trainer import train_product_model_with_tcn
|
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|
from predictors.model_predictor import load_model_and_predict
|
|
|
|
from generate_pharmacy_data import generate_pharmacy_data
|
|
|
|
def clear_screen():
|
|
"""清空控制台屏幕"""
|
|
os.system('cls' if os.name == 'nt' else 'clear')
|
|
|
|
def print_header():
|
|
"""打印应用程序标题"""
|
|
print("=" * 80)
|
|
print(" " * 25 + "药店销售预测系统" + " " * 25)
|
|
print("=" * 80)
|
|
|
|
def main_menu():
|
|
"""显示主菜单并获取用户选择"""
|
|
print("\n主菜单:")
|
|
print("1. 训练所有药品的销售预测模型")
|
|
print("2. 训练单个药品的销售预测模型 (Transformer)")
|
|
print("3. 训练单个药品的销售预测模型 (mLSTM)")
|
|
print("4. 训练单个药品的销售预测模型 (KAN)")
|
|
print("5. 训练单个药品的销售预测模型 (优化版KAN)")
|
|
print("6. 比较原始KAN和优化版KAN模型性能")
|
|
print("7. 训练单个药品的销售预测模型 (TCN)")
|
|
print("8. 查看已有预测结果")
|
|
print("9. 使用已训练的模型进行预测")
|
|
print("10. 比较不同模型的预测结果")
|
|
print("11. 模型管理")
|
|
print("0. 退出")
|
|
|
|
choice = input("\n请输入您的选择 (0-11): ")
|
|
return choice
|
|
|
|
def train_all_products_menu(predictor):
|
|
"""训练所有药品的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
print("\n请选择要使用的模型类型:")
|
|
print("1. Transformer")
|
|
print("2. mLSTM")
|
|
print("3. KAN")
|
|
print("4. TCN")
|
|
print("0. 返回主菜单")
|
|
|
|
choice = input("\n请输入您的选择 (0-4): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
model_types = {
|
|
'1': 'transformer',
|
|
'2': 'mlstm',
|
|
'3': 'kan',
|
|
'4': 'tcn'
|
|
}
|
|
|
|
if choice in model_types:
|
|
model_type = model_types[choice]
|
|
|
|
# 获取所有产品ID
|
|
product_ids = predictor.data['product_id'].unique()
|
|
|
|
for product_id in product_ids:
|
|
print(f"\n正在训练产品 {product_id} 的模型...")
|
|
predictor.train_model(product_id, model_type=model_type)
|
|
|
|
print("\n所有产品的模型训练完成!")
|
|
input("\n按Enter键继续...")
|
|
else:
|
|
print("\n无效的选择!")
|
|
input("\n按Enter键继续...")
|
|
|
|
def train_single_product_menu(predictor, model_type, use_optimized=False):
|
|
"""训练单个药品的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
|
|
model_name = model_type
|
|
if model_type == 'kan' and use_optimized:
|
|
model_name = "优化版KAN"
|
|
|
|
print(f"\n训练单个药品的{model_name}模型")
|
|
|
|
# 获取所有产品ID
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
try:
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(product_ids):
|
|
product_id = product_ids[idx]
|
|
|
|
# 获取训练参数
|
|
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
|
|
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
|
|
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
|
|
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
|
|
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
|
|
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
|
|
num_layers = int(input("请输入层数 (默认2): ") or "2")
|
|
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
|
|
|
|
print(f"\n开始训练产品 {product_id} 的{model_name}模型...")
|
|
|
|
if model_type == 'kan' and use_optimized:
|
|
# 使用优化版KAN模型
|
|
metrics = predictor.train_model(
|
|
product_id=product_id,
|
|
model_type='optimized_kan',
|
|
epochs=epochs,
|
|
batch_size=batch_size,
|
|
learning_rate=learning_rate,
|
|
sequence_length=sequence_length,
|
|
forecast_horizon=forecast_horizon,
|
|
hidden_size=hidden_size,
|
|
num_layers=num_layers,
|
|
dropout=dropout
|
|
)
|
|
else:
|
|
# 使用普通模型
|
|
metrics = predictor.train_model(
|
|
product_id=product_id,
|
|
model_type=model_type,
|
|
epochs=epochs,
|
|
batch_size=batch_size,
|
|
learning_rate=learning_rate,
|
|
sequence_length=sequence_length,
|
|
forecast_horizon=forecast_horizon,
|
|
hidden_size=hidden_size,
|
|
num_layers=num_layers,
|
|
dropout=dropout
|
|
)
|
|
|
|
print("\n模型训练完成!")
|
|
print(f"MSE: {metrics['mse']:.4f}")
|
|
print(f"RMSE: {metrics['rmse']:.4f}")
|
|
print(f"MAE: {metrics['mae']:.4f}")
|
|
print(f"R²: {metrics['r2']:.4f}")
|
|
print(f"MAPE: {metrics['mape']:.2f}%")
|
|
print(f"训练时间: {metrics['training_time']:.2f}秒")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
|
|
def compare_kan_models_menu(predictor):
|
|
"""比较原始KAN和优化版KAN模型性能的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
print("\n比较原始KAN和优化版KAN模型性能")
|
|
|
|
# 获取所有产品ID
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
try:
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(product_ids):
|
|
product_id = product_ids[idx]
|
|
|
|
# 获取训练参数
|
|
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
|
|
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
|
|
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
|
|
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
|
|
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
|
|
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
|
|
num_layers = int(input("请输入层数 (默认2): ") or "2")
|
|
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
|
|
|
|
print(f"\n开始比较产品 {product_id} 的原始KAN和优化版KAN模型性能...")
|
|
|
|
comparison = predictor.compare_kan_models(
|
|
product_id=product_id,
|
|
epochs=epochs,
|
|
batch_size=batch_size,
|
|
learning_rate=learning_rate,
|
|
sequence_length=sequence_length,
|
|
forecast_horizon=forecast_horizon,
|
|
hidden_size=hidden_size,
|
|
num_layers=num_layers,
|
|
dropout=dropout
|
|
)
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
|
|
def view_predictions_menu(predictor):
|
|
"""查看已有预测结果的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
print("\n查看已有预测结果")
|
|
|
|
# 获取所有产品ID
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
try:
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(product_ids):
|
|
product_id = product_ids[idx]
|
|
|
|
# 获取可用模型
|
|
models = predictor.list_available_models(product_id)
|
|
|
|
if models:
|
|
print(f"\n产品 {product_id} 的可用模型:")
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1}. {model['model_type']}")
|
|
|
|
model_choice = input("\n请选择模型的编号 (或输入0返回): ")
|
|
|
|
if model_choice == '0':
|
|
return
|
|
|
|
model_idx = int(model_choice) - 1
|
|
if 0 <= model_idx < len(models):
|
|
model_type = models[model_idx]['model_type']
|
|
|
|
# 加载并显示预测结果
|
|
result = predictor.predict(product_id, model_type, analyze_result=True)
|
|
|
|
if result:
|
|
print(f"\n{product_id} - {model_type}模型预测结果已保存到图像文件")
|
|
print(f"文件名: {product_id}_{model_type}_prediction.png")
|
|
else:
|
|
print(f"\n无法加载 {product_id} 的 {model_type} 模型预测结果")
|
|
else:
|
|
print("\n无效的选择!")
|
|
else:
|
|
print(f"\n产品 {product_id} 没有可用的预训练模型")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
|
|
def predict_menu(predictor):
|
|
"""使用已训练的模型进行预测的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
print("\n使用已训练的模型进行预测")
|
|
|
|
# 获取所有产品ID
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
try:
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(product_ids):
|
|
product_id = product_ids[idx]
|
|
|
|
# 获取可用模型
|
|
models = predictor.list_available_models(product_id)
|
|
|
|
if models:
|
|
print(f"\n产品 {product_id} 的可用模型:")
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1}. {model['model_type']}")
|
|
|
|
model_choice = input("\n请选择模型的编号 (或输入0返回): ")
|
|
|
|
if model_choice == '0':
|
|
return
|
|
|
|
model_idx = int(model_choice) - 1
|
|
if 0 <= model_idx < len(models):
|
|
model_type = models[model_idx]['model_type']
|
|
|
|
# 获取预测参数
|
|
future_days = int(input("\n请输入预测天数 (默认7): ") or "7")
|
|
analyze = input("是否分析预测结果? (y/n, 默认y): ").lower() != 'n'
|
|
|
|
# 进行预测
|
|
result = predictor.predict(
|
|
product_id,
|
|
model_type,
|
|
future_days=future_days,
|
|
analyze_result=analyze
|
|
)
|
|
|
|
if result:
|
|
print(f"\n{product_id} - {model_type}模型预测结果已保存到图像文件")
|
|
print(f"文件名: {product_id}_{model_type}_prediction.png")
|
|
|
|
if analyze and 'analysis' in result and result['analysis']:
|
|
print("\n预测结果分析:")
|
|
print(result['analysis']['explanation'])
|
|
else:
|
|
print(f"\n无法使用 {product_id} 的 {model_type} 模型进行预测")
|
|
else:
|
|
print("\n无效的选择!")
|
|
else:
|
|
print(f"\n产品 {product_id} 没有可用的预训练模型")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
|
|
def compare_models_menu(predictor):
|
|
"""比较不同模型的预测结果的菜单"""
|
|
clear_screen()
|
|
print_header()
|
|
print("\n比较不同模型的预测结果")
|
|
|
|
# 获取所有产品ID
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
|
|
try:
|
|
idx = int(choice) - 1
|
|
if 0 <= idx < len(product_ids):
|
|
product_id = product_ids[idx]
|
|
|
|
# 获取可用模型
|
|
models = predictor.list_available_models(product_id)
|
|
|
|
if len(models) >= 2:
|
|
print(f"\n产品 {product_id} 的可用模型:")
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1}. {model['model_type']}")
|
|
|
|
print("\n请选择要比较的模型 (输入模型编号,用逗号分隔):")
|
|
model_choices = input("例如: 1,3,4 (或输入0返回): ")
|
|
|
|
if model_choices == '0':
|
|
return
|
|
|
|
try:
|
|
model_indices = [int(x.strip()) - 1 for x in model_choices.split(',')]
|
|
selected_models = []
|
|
|
|
for model_idx in model_indices:
|
|
if 0 <= model_idx < len(models):
|
|
selected_models.append(models[model_idx]['model_type'])
|
|
|
|
if len(selected_models) >= 2:
|
|
print(f"\n比较产品 {product_id} 的 {', '.join(selected_models)} 模型预测结果...")
|
|
|
|
# 获取预测参数
|
|
future_days = int(input("\n请输入预测天数 (默认7): ") or "7")
|
|
|
|
# 比较模型
|
|
predictions = {}
|
|
dates = None
|
|
|
|
for model_type in selected_models:
|
|
result = predictor.predict(
|
|
product_id,
|
|
model_type,
|
|
future_days=future_days
|
|
)
|
|
|
|
if result and 'predictions' in result:
|
|
predictions[model_type] = result['predictions']['predicted_sales'].values
|
|
if dates is None:
|
|
dates = result['predictions']['date'].values
|
|
|
|
if predictions:
|
|
# 绘制比较图
|
|
plt.figure(figsize=(12, 6))
|
|
|
|
for model_type, values in predictions.items():
|
|
plt.plot(dates, values, label=model_type)
|
|
|
|
plt.title(f'产品 {product_id} - 多模型预测结果比较')
|
|
plt.xlabel('日期')
|
|
plt.ylabel('销量')
|
|
plt.legend()
|
|
plt.grid(True)
|
|
plt.xticks(rotation=45)
|
|
plt.tight_layout()
|
|
|
|
# 保存图像
|
|
comparison_file = f'{product_id}_model_comparison.png'
|
|
plt.savefig(comparison_file)
|
|
plt.close()
|
|
|
|
print(f"\n模型比较结果已保存到图像文件: {comparison_file}")
|
|
else:
|
|
print("\n无法获取预测结果进行比较")
|
|
else:
|
|
print("\n请至少选择两个模型进行比较")
|
|
except ValueError:
|
|
print("\n无效的输入!")
|
|
else:
|
|
print(f"\n产品 {product_id} 没有足够的预训练模型进行比较 (至少需要2个)")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
|
|
def model_management_menu(predictor):
|
|
"""模型管理菜单"""
|
|
while True:
|
|
clear_screen()
|
|
print_header()
|
|
print("\n模型管理")
|
|
print("\n1. 查看所有模型")
|
|
print("2. 查看特定产品的模型")
|
|
print("3. 删除模型")
|
|
print("0. 返回主菜单")
|
|
|
|
choice = input("\n请输入您的选择 (0-3): ")
|
|
|
|
if choice == '0':
|
|
return
|
|
elif choice == '1':
|
|
# 查看所有模型
|
|
models = predictor.list_available_models()
|
|
|
|
if models:
|
|
print("\n所有可用模型:")
|
|
print(f"{'序号':<5} {'产品ID':<10} {'模型类型':<15} {'文件名'}")
|
|
print("-" * 60)
|
|
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1:<5} {model['product_id']:<10} {model['model_type']:<15} {model['file_name']}")
|
|
else:
|
|
print("\n没有可用的模型")
|
|
|
|
input("\n按Enter键继续...")
|
|
elif choice == '2':
|
|
# 查看特定产品的模型
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
product_choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if product_choice == '0':
|
|
continue
|
|
|
|
try:
|
|
product_idx = int(product_choice) - 1
|
|
if 0 <= product_idx < len(product_ids):
|
|
product_id = product_ids[product_idx]
|
|
|
|
models = predictor.list_available_models(product_id)
|
|
|
|
if models:
|
|
print(f"\n产品 {product_id} 的可用模型:")
|
|
print(f"{'序号':<5} {'模型类型':<15} {'文件名'}")
|
|
print("-" * 50)
|
|
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1:<5} {model['model_type']:<15} {model['file_name']}")
|
|
else:
|
|
print(f"\n产品 {product_id} 没有可用的模型")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
elif choice == '3':
|
|
# 删除模型
|
|
if predictor.data is not None:
|
|
product_ids = predictor.data['product_id'].unique()
|
|
print("\n可用的产品ID:")
|
|
|
|
for i, product_id in enumerate(product_ids):
|
|
print(f"{i+1}. {product_id}")
|
|
|
|
product_choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
|
|
|
if product_choice == '0':
|
|
continue
|
|
|
|
try:
|
|
product_idx = int(product_choice) - 1
|
|
if 0 <= product_idx < len(product_ids):
|
|
product_id = product_ids[product_idx]
|
|
|
|
models = predictor.list_available_models(product_id)
|
|
|
|
if models:
|
|
print(f"\n产品 {product_id} 的可用模型:")
|
|
print(f"{'序号':<5} {'模型类型':<15} {'文件名'}")
|
|
print("-" * 50)
|
|
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1:<5} {model['model_type']:<15} {model['file_name']}")
|
|
|
|
model_choice = input("\n请选择要删除的模型编号 (或输入0返回): ")
|
|
|
|
if model_choice == '0':
|
|
continue
|
|
|
|
model_idx = int(model_choice) - 1
|
|
if 0 <= model_idx < len(models):
|
|
model_type = models[model_idx]['model_type']
|
|
|
|
confirm = input(f"\n确定要删除产品 {product_id} 的 {model_type} 模型吗? (y/n): ").lower()
|
|
|
|
if confirm == 'y':
|
|
if predictor.delete_model(product_id, model_type):
|
|
print(f"\n已成功删除产品 {product_id} 的 {model_type} 模型")
|
|
else:
|
|
print(f"\n删除产品 {product_id} 的 {model_type} 模型失败")
|
|
else:
|
|
print("\n无效的选择!")
|
|
else:
|
|
print(f"\n产品 {product_id} 没有可用的模型")
|
|
else:
|
|
print("\n无效的选择!")
|
|
except (ValueError, IndexError):
|
|
print("\n无效的输入!")
|
|
else:
|
|
print("\n没有可用的数据。请先生成或加载数据。")
|
|
|
|
input("\n按Enter键继续...")
|
|
else:
|
|
print("\n无效的选择!")
|
|
input("\n按Enter键继续...")
|
|
|
|
def main():
|
|
"""主函数"""
|
|
# 检查数据文件是否存在,如果不存在则生成模拟数据
|
|
if not os.path.exists('pharmacy_sales.xlsx'):
|
|
print("数据文件不存在,正在生成模拟数据...")
|
|
generate_pharmacy_data()
|
|
|
|
# 创建预测器实例
|
|
predictor = PharmacyPredictor()
|
|
|
|
while True:
|
|
clear_screen()
|
|
print_header()
|
|
|
|
choice = main_menu()
|
|
|
|
if choice == '0':
|
|
print("\n感谢使用药店销售预测系统!再见!")
|
|
break
|
|
elif choice == '1':
|
|
train_all_products_menu(predictor)
|
|
elif choice == '2':
|
|
train_single_product_menu(predictor, 'transformer')
|
|
elif choice == '3':
|
|
train_single_product_menu(predictor, 'mlstm')
|
|
elif choice == '4':
|
|
train_single_product_menu(predictor, 'kan')
|
|
elif choice == '5':
|
|
train_single_product_menu(predictor, 'kan', use_optimized=True)
|
|
elif choice == '6':
|
|
compare_kan_models_menu(predictor)
|
|
elif choice == '7':
|
|
train_single_product_menu(predictor, 'tcn')
|
|
elif choice == '8':
|
|
view_predictions_menu(predictor)
|
|
elif choice == '9':
|
|
predict_menu(predictor)
|
|
elif choice == '10':
|
|
compare_models_menu(predictor)
|
|
elif choice == '11':
|
|
model_management_menu(predictor)
|
|
else:
|
|
print("\n无效的选择!请重试。")
|
|
input("\n按Enter键继续...")
|
|
|
|
if __name__ == "__main__":
|
|
main() |