
1. 修复前端图表日期排序问题: - 改进 PredictionView.vue 和 HistoryView.vue 中的图表渲染逻辑 - 确保历史数据和预测数据按照正确的日期顺序显示 2. 修复后端API处理: - 解决 optimized_kan 模型类型的路径映射问题 - 添加 JSON 序列化器处理 Pandas Timestamp 对象 - 改进预测数据与历史数据的衔接处理 3. 优化图表样式和用户体验
279 lines
11 KiB
Python
279 lines
11 KiB
Python
import os
|
||
import pandas as pd
|
||
import numpy as np
|
||
import matplotlib.pyplot as plt
|
||
from datetime import datetime
|
||
import sys
|
||
|
||
# 导入模块
|
||
from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer
|
||
from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict
|
||
from models import ModelManager
|
||
from pharmacy_predictor import PharmacyPredictor
|
||
from generate_pharmacy_data import generate_pharmacy_data
|
||
|
||
def clear_screen():
|
||
"""清空控制台屏幕"""
|
||
os.system('cls' if os.name == 'nt' else 'clear')
|
||
|
||
def print_header():
|
||
"""打印应用程序标题"""
|
||
print("=" * 80)
|
||
print(" " * 25 + "药店销售预测系统" + " " * 25)
|
||
print("=" * 80)
|
||
|
||
def main_menu():
|
||
"""显示主菜单并获取用户选择"""
|
||
print("\n主菜单:")
|
||
print("1. 训练所有药品的销售预测模型")
|
||
print("2. 训练单个药品的销售预测模型 (Transformer)")
|
||
print("3. 训练单个药品的销售预测模型 (mLSTM)")
|
||
print("4. 训练单个药品的销售预测模型 (KAN)")
|
||
print("5. 训练单个药品的销售预测模型 (优化版KAN)") # 新增选项
|
||
print("6. 比较原始KAN和优化版KAN模型性能") # 新增选项
|
||
print("7. 查看已有预测结果")
|
||
print("8. 使用已训练的模型进行预测")
|
||
print("9. 比较不同模型的预测结果")
|
||
print("10. 模型管理")
|
||
print("0. 退出")
|
||
|
||
choice = input("\n请输入您的选择 (0-10): ")
|
||
return choice
|
||
|
||
def train_all_products_menu(predictor):
|
||
"""训练所有药品的菜单"""
|
||
clear_screen()
|
||
print_header()
|
||
print("\n选择模型类型:")
|
||
print("1. Transformer")
|
||
print("2. mLSTM")
|
||
print("3. KAN")
|
||
print("4. 优化版KAN") # 新增选项
|
||
print("0. 返回主菜单")
|
||
|
||
choice = input("\n请输入您的选择 (0-4): ")
|
||
|
||
if choice == '0':
|
||
return
|
||
|
||
model_types = {
|
||
'1': 'transformer',
|
||
'2': 'mlstm',
|
||
'3': 'kan',
|
||
'4': 'kan' # 使用相同的模型类型,但在训练时指定use_optimized=True
|
||
}
|
||
|
||
if choice in model_types:
|
||
model_type = model_types[choice]
|
||
use_optimized = (choice == '4') # 如果选择4,则使用优化版KAN
|
||
|
||
# 获取所有产品ID
|
||
product_ids = predictor.data['product_id'].unique()
|
||
|
||
for product_id in product_ids:
|
||
print(f"\n正在训练产品 {product_id} 的模型...")
|
||
if model_type == 'kan' and use_optimized:
|
||
predictor.train_optimized_kan_model(product_id)
|
||
else:
|
||
predictor.train_model(product_id, model_type=model_type)
|
||
|
||
print("\n所有产品的模型训练完成!")
|
||
input("\n按Enter键继续...")
|
||
else:
|
||
print("\n无效的选择!")
|
||
input("\n按Enter键继续...")
|
||
|
||
def train_single_product_menu(predictor, model_type, use_optimized=False):
|
||
"""训练单个药品的菜单"""
|
||
clear_screen()
|
||
print_header()
|
||
|
||
model_name = model_type
|
||
if model_type == 'kan' and use_optimized:
|
||
model_name = "优化版KAN"
|
||
|
||
print(f"\n训练单个药品的{model_name}模型")
|
||
|
||
# 获取所有产品ID
|
||
if predictor.data is not None:
|
||
product_ids = predictor.data['product_id'].unique()
|
||
print("\n可用的产品ID:")
|
||
for i, product_id in enumerate(product_ids):
|
||
print(f"{i+1}. {product_id}")
|
||
|
||
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
||
|
||
if choice == '0':
|
||
return
|
||
|
||
try:
|
||
idx = int(choice) - 1
|
||
if 0 <= idx < len(product_ids):
|
||
product_id = product_ids[idx]
|
||
|
||
# 获取训练参数
|
||
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
|
||
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
|
||
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
|
||
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
|
||
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
|
||
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
|
||
num_layers = int(input("请输入层数 (默认2): ") or "2")
|
||
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
|
||
|
||
print(f"\n开始训练产品 {product_id} 的{model_name}模型...")
|
||
|
||
if model_type == 'kan' and use_optimized:
|
||
# 使用优化版KAN模型
|
||
metrics = predictor.train_optimized_kan_model(
|
||
product_id=product_id,
|
||
epochs=epochs,
|
||
batch_size=batch_size,
|
||
learning_rate=learning_rate,
|
||
sequence_length=sequence_length,
|
||
forecast_horizon=forecast_horizon,
|
||
hidden_size=hidden_size,
|
||
num_layers=num_layers,
|
||
dropout=dropout
|
||
)
|
||
else:
|
||
# 使用普通模型
|
||
metrics = predictor.train_model(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
epochs=epochs,
|
||
batch_size=batch_size,
|
||
learning_rate=learning_rate,
|
||
sequence_length=sequence_length,
|
||
forecast_horizon=forecast_horizon,
|
||
hidden_size=hidden_size,
|
||
num_layers=num_layers,
|
||
dropout=dropout
|
||
)
|
||
|
||
print("\n模型训练完成!")
|
||
print(f"MSE: {metrics['mse']:.4f}")
|
||
print(f"RMSE: {metrics['rmse']:.4f}")
|
||
print(f"MAE: {metrics['mae']:.4f}")
|
||
print(f"R²: {metrics['r2']:.4f}")
|
||
print(f"MAPE: {metrics['mape']:.2f}%")
|
||
print(f"训练时间: {metrics['training_time']:.2f}秒")
|
||
else:
|
||
print("\n无效的选择!")
|
||
except (ValueError, IndexError):
|
||
print("\n无效的输入!")
|
||
else:
|
||
print("\n没有可用的数据。请先生成或加载数据。")
|
||
|
||
input("\n按Enter键继续...")
|
||
|
||
def compare_kan_models_menu(predictor):
|
||
"""比较原始KAN和优化版KAN模型性能的菜单"""
|
||
clear_screen()
|
||
print_header()
|
||
print("\n比较原始KAN和优化版KAN模型性能")
|
||
|
||
# 获取所有产品ID
|
||
if predictor.data is not None:
|
||
product_ids = predictor.data['product_id'].unique()
|
||
print("\n可用的产品ID:")
|
||
for i, product_id in enumerate(product_ids):
|
||
print(f"{i+1}. {product_id}")
|
||
|
||
choice = input("\n请选择产品ID的编号 (或输入0返回): ")
|
||
|
||
if choice == '0':
|
||
return
|
||
|
||
try:
|
||
idx = int(choice) - 1
|
||
if 0 <= idx < len(product_ids):
|
||
product_id = product_ids[idx]
|
||
|
||
# 获取训练参数
|
||
epochs = int(input("\n请输入训练轮数 (默认100): ") or "100")
|
||
batch_size = int(input("请输入批次大小 (默认32): ") or "32")
|
||
learning_rate = float(input("请输入学习率 (默认0.001): ") or "0.001")
|
||
sequence_length = int(input("请输入输入序列长度 (默认30): ") or "30")
|
||
forecast_horizon = int(input("请输入预测天数 (默认7): ") or "7")
|
||
hidden_size = int(input("请输入隐藏层大小 (默认64): ") or "64")
|
||
num_layers = int(input("请输入层数 (默认2): ") or "2")
|
||
dropout = float(input("请输入Dropout比例 (默认0.1): ") or "0.1")
|
||
|
||
print(f"\n开始比较产品 {product_id} 的原始KAN和优化版KAN模型...")
|
||
|
||
# 使用比较方法
|
||
comparison = predictor.compare_kan_models(
|
||
product_id=product_id,
|
||
epochs=epochs,
|
||
batch_size=batch_size,
|
||
learning_rate=learning_rate,
|
||
sequence_length=sequence_length,
|
||
forecast_horizon=forecast_horizon,
|
||
hidden_size=hidden_size,
|
||
num_layers=num_layers,
|
||
dropout=dropout
|
||
)
|
||
|
||
print("\n比较完成!")
|
||
else:
|
||
print("\n无效的选择!")
|
||
except (ValueError, IndexError):
|
||
print("\n无效的输入!")
|
||
else:
|
||
print("\n没有可用的数据。请先生成或加载数据。")
|
||
|
||
input("\n按Enter键继续...")
|
||
|
||
def main():
|
||
"""主函数"""
|
||
# 检查数据文件是否存在,如果不存在则生成
|
||
data_path = 'pharmacy_sales.xlsx'
|
||
if not os.path.exists(data_path):
|
||
print("生成药店销售数据...")
|
||
generate_pharmacy_data(num_products=10, days=365, output_file=data_path)
|
||
print(f"数据已生成并保存到 {data_path}")
|
||
|
||
# 创建预测器实例
|
||
predictor = PharmacyPredictor(data_path=data_path)
|
||
|
||
while True:
|
||
clear_screen()
|
||
print_header()
|
||
choice = main_menu()
|
||
|
||
if choice == '0':
|
||
print("\n感谢使用药店销售预测系统,再见!")
|
||
break
|
||
elif choice == '1':
|
||
train_all_products_menu(predictor)
|
||
elif choice == '2':
|
||
train_single_product_menu(predictor, model_type='transformer')
|
||
elif choice == '3':
|
||
train_single_product_menu(predictor, model_type='mlstm')
|
||
elif choice == '4':
|
||
train_single_product_menu(predictor, model_type='kan')
|
||
elif choice == '5':
|
||
# 使用优化版KAN模型
|
||
train_single_product_menu(predictor, model_type='kan', use_optimized=True)
|
||
elif choice == '6':
|
||
# 比较原始KAN和优化版KAN模型性能
|
||
compare_kan_models_menu(predictor)
|
||
elif choice == '7':
|
||
# 查看已有预测结果
|
||
pass
|
||
elif choice == '8':
|
||
# 使用已训练的模型进行预测
|
||
pass
|
||
elif choice == '9':
|
||
# 比较不同模型的预测结果
|
||
pass
|
||
elif choice == '10':
|
||
# 模型管理
|
||
pass
|
||
else:
|
||
print("\n无效的选择,请重试!")
|
||
input("\n按Enter键继续...")
|
||
|
||
if __name__ == "__main__":
|
||
main() |