import os import pandas as pd import numpy as np import matplotlib.pyplot as plt from datetime import datetime import sys # 导入模块 from pharmacy_predictor import train_product_model, train_product_model_with_mlstm, train_product_model_with_kan, train_product_model_with_transformer from pharmacy_predictor import load_model_and_predict, load_kan_model_and_predict from models import ModelManager def print_header(): print("\n" + "="*60) print("📊 药店单品销售预测系统 📊") print("="*60) def main(): # 初始化模型管理器 model_manager = ModelManager() # 首先检查数据文件是否存在 try: df = pd.read_excel('pharmacy_sales.xlsx') print("✅ 检测到现有数据文件 'pharmacy_sales.xlsx'") # 获取所有产品 products = df[['product_id', 'product_name']].drop_duplicates().sort_values('product_id') print(f"\n📋 发现 {len(products)} 种药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") except: print("❌ 未找到数据文件。正在生成模拟数据...") import generate_pharmacy_data df = pd.read_excel('pharmacy_sales.xlsx') print("✅ 数据生成完成!") # 获取所有产品 products = df[['product_id', 'product_name']].drop_duplicates().sort_values('product_id') while True: print("\n" + "="*60) print("📋 请选择操作:") print(" 1. 训练所有药品的销售预测模型") print(" 2. 训练单个药品的销售预测模型") print(" 3. 使用mLSTM模型训练单个药品的销售预测模型") print(" 4. 使用KAN模型训练单个药品的销售预测模型") print(" 5. 使用Transformer模型训练单个药品的销售预测模型") print(" 6. 查看已有预测结果") print(" 7. 使用已训练的模型进行预测") print(" 8. 比较不同模型的预测结果") print(" 9. 模型管理") print(" 0. 退出") print("="*60) choice = input("\n请输入选项 (0-9): ") if choice == '0': print("感谢使用药店销售预测系统!再见!") break elif choice == '1': # 训练所有药品的预测模型 print("\n开始训练所有药品的销售预测模型...") all_metrics = {} for _, row in products.iterrows(): product_id = row['product_id'] print(f"\n{'='*50}") print(f"开始训练产品 {row['product_name']} (ID: {product_id}) 的模型") print(f"{'='*50}") _, metrics = train_product_model(product_id) all_metrics[product_id] = metrics # 输出所有产品的评估指标 print("\n所有产品模型评估结果汇总:") for product_id, metrics in all_metrics.items(): product_name = df[df['product_id'] == product_id]['product_name'].iloc[0] print(f"\n{product_name} (ID: {product_id}):") for metric, value in metrics.items(): print(f" {metric}: {value:.4f}") print("\n模型训练和评估完成!") elif choice == '2': # 训练单个药品的预测模型 print("\n请选择要训练的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] print(f"\n开始训练 {product_name} (ID: {product_id}) 的销售预测模型...") _, metrics = train_product_model(product_id) print(f"\n{product_name} 模型评估指标:") for metric, value in metrics.items(): print(f" {metric}: {value:.4f}") print(f"\n模型训练和评估完成!") else: print("\n❌ 无效的选择!") elif choice == '3': # 使用mLSTM模型训练单个药品 print("\n请选择要训练的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] print(f"\n开始使用mLSTM模型训练 {product_name} (ID: {product_id}) 的销售预测模型...") model, metrics = train_product_model_with_mlstm(product_id) # 保存模型到模型管理器 try: model_manager.save_model( model=model, model_type='mlstm', product_id=product_id, metrics=metrics, features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'], look_back=14, T=7 ) print("✅ 模型已保存到模型管理器") except Exception as e: print(f"❌ 保存模型到管理器时出错: {str(e)}") print(f"\n{product_name} mLSTM模型评估指标:") for metric, value in metrics.items(): print(f" {metric}: {value:.4f}") print(f"\n模型训练和评估完成!") else: print("\n❌ 无效的选择!") elif choice == '4': # 使用KAN模型训练单个药品 print("\n请选择要训练的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] print(f"\n开始使用KAN模型训练 {product_name} (ID: {product_id}) 的销售预测模型...") model, metrics = train_product_model_with_kan(product_id) # 保存模型到模型管理器 try: model_manager.save_model( model=model, model_type='kan', product_id=product_id, metrics=metrics, features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'], look_back=14, T=7 ) print("✅ 模型已保存到模型管理器") except Exception as e: print(f"❌ 保存模型到管理器时出错: {str(e)}") print(f"\n{product_name} KAN模型评估指标:") for metric, value in metrics.items(): print(f" {metric}: {value:.4f}") print(f"\n模型训练和评估完成!") else: print("\n❌ 无效的选择!") elif choice == '5': # 使用Transformer模型训练单个药品 print("\n请选择要训练的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] print(f"\n开始使用Transformer模型训练 {product_name} (ID: {product_id}) 的销售预测模型...") model, metrics = train_product_model_with_transformer(product_id) # 保存模型到模型管理器 try: model_manager.save_model( model=model, model_type='transformer', product_id=product_id, metrics=metrics, features=['sales', 'price', 'weekday', 'month', 'is_holiday', 'is_weekend', 'is_promotion', 'temperature'], look_back=14, T=7 ) print("✅ 模型已保存到模型管理器") except Exception as e: print(f"❌ 保存模型到管理器时出错: {str(e)}") print(f"\n{product_name} Transformer模型评估指标:") for metric, value in metrics.items(): print(f" {metric}: {value:.4f}") print(f"\n模型训练和评估完成!") else: print("\n❌ 无效的选择!") elif choice == '6': # 查看已有预测结果 print("\n正在搜索预测结果文件...") prediction_files = [f for f in os.listdir() if f.endswith('_prediction_results.csv')] if not prediction_files: print("❌ 未找到任何预测结果文件。请先训练模型生成预测结果。") continue print(f"\n找到 {len(prediction_files)} 个预测结果文件:") for i, file in enumerate(prediction_files, 1): print(f" {i}. {file}") file_idx = int(input("\n请选择要查看的文件 (1-{}): ".format(len(prediction_files)))) if 1 <= file_idx <= len(prediction_files): file_path = prediction_files[file_idx-1] try: results_df = pd.read_csv(file_path) print(f"\n{file_path} 内容:") print(results_df) # 可视化结果 plt.figure(figsize=(12, 6)) plt.plot(results_df['date'], results_df['actual_sales'], 'b-', label='实际销量') plt.plot(results_df['date'], results_df['predicted_sales'], 'r--', label='预测销量') plt.title('销量预测结果') plt.xlabel('日期') plt.ylabel('销量') plt.legend() plt.grid(True) plt.xticks(rotation=45) plt.tight_layout() plt.show() except Exception as e: print(f"❌ 读取文件时出错: {str(e)}") else: print("\n❌ 无效的选择!") elif choice == '7': # 使用已训练的模型进行预测 print("\n请选择要使用的模型类型:") print(" 1. mLSTM模型") print(" 2. KAN模型") print(" 3. Transformer模型") model_choice = input("\n请输入选项 (1-3): ") if model_choice not in ['1', '2', '3']: print("\n❌ 无效的选择!") continue # 选择产品 print("\n请选择要预测的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] # 使用模型管理器进行预测 try: if model_choice == '1': # mLSTM模型 model_type = 'mlstm' elif model_choice == '2': # KAN模型 model_type = 'kan' else: # Transformer模型 model_type = 'transformer' print(f"\n使用{model_type.upper()}模型预测 {product_name} (ID: {product_id}) 的未来销量...") # 使用模型管理器进行预测 product_df = df[df['product_id'] == product_id].sort_values('date') predictions = model_manager.predict_with_model( product_id=product_id, model_type=model_type, product_df=product_df ) if predictions is not None: print("\n✅ 预测完成!") except Exception as e: print(f"\n❌ 预测时出错: {str(e)}") print("请确保已经训练并保存了对应的模型。") else: print("\n❌ 无效的选择!") elif choice == '8': # 比较不同模型的预测结果 print("\n请选择要比较的药品:") for i, (_, row) in enumerate(products.iterrows(), 1): print(f" {i}. {row['product_name']} (ID: {row['product_id']})") product_idx = int(input("\n请输入药品编号 (1-{}): ".format(len(products)))) if 1 <= product_idx <= len(products): product_id = products.iloc[product_idx-1]['product_id'] product_name = products.iloc[product_idx-1]['product_name'] print(f"\n比较 {product_name} (ID: {product_id}) 的不同模型预测结果...") # 使用模型管理器进行比较 try: product_df = df[df['product_id'] == product_id].sort_values('date') comparison = model_manager.compare_models( product_id=product_id, product_df=product_df ) if comparison is not None: print("\n✅ 比较完成!") except Exception as e: print(f"\n❌ 比较时出错: {str(e)}") print("请确保已经训练并保存了对应的模型。") else: print("\n❌ 无效的选择!") elif choice == '9': # 模型管理 print("\n启动模型管理工具...") import model_management model_management.interactive_mode() else: print("\n❌ 无效的选项!请重新输入。") if __name__ == "__main__": print_header() main()