import os import pandas as pd import torch import argparse from models import ModelManager def main(): """模型管理工具主函数""" parser = argparse.ArgumentParser(description='药店销售预测系统 - 模型管理工具') parser.add_argument('--action', type=str, required=True, choices=['list', 'details', 'predict', 'compare', 'delete', 'export', 'import'], help='执行的操作类型') parser.add_argument('--product_id', type=str, help='产品ID,例如P001') parser.add_argument('--model_type', type=str, help='模型类型 (mlstm/transformer/kan)') parser.add_argument('--version', type=str, help='模型版本') parser.add_argument('--file_path', type=str, help='导入/导出的文件路径') parser.add_argument('--export_dir', type=str, default='exported_models', help='导出目录') parser.add_argument('--compare_models', type=str, help='要比较的模型类型,用逗号分隔') args = parser.parse_args() # 初始化模型管理器 model_manager = ModelManager() # 根据操作类型执行相应的功能 if args.action == 'list': # 列出模型 models = model_manager.list_models( product_id=args.product_id, model_type=args.model_type ) if not models: print("没有找到匹配的模型。") return print(f"\n找到 {len(models)} 个模型:") for i, model in enumerate(models, 1): print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})") print(f" 创建时间: {model['created_at']}") print(f" 文件大小: {model['file_size']}") # 如果有评估指标,显示它们 if 'metrics' in model: print(" 评估指标:") for metric, value in model['metrics'].items(): print(f" {metric}: {value}") elif args.action == 'details': # 显示模型详情 if not args.product_id or not args.model_type: print("错误: 需要指定product_id和model_type") return details = model_manager.get_model_details( product_id=args.product_id, model_type=args.model_type, version=args.version ) if not details: print("未找到模型详情。") return print(f"\n{args.product_id} - {args.model_type.upper()} 模型详情:") for key, value in details.items(): if key == 'metrics': print(f"\n评估指标:") for metric, metric_value in value.items(): print(f" {metric}: {metric_value}") else: print(f"{key}: {value}") elif args.action == 'predict': # 使用模型进行预测 if not args.product_id: print("错误: 需要指定product_id") return # 使用指定的模型类型,如果未指定则使用mlstm model_type = args.model_type if args.model_type else 'mlstm' # 执行预测 predictions = model_manager.predict_with_model( product_id=args.product_id, model_type=model_type, version=args.version ) if predictions is None: print("预测失败。") elif args.action == 'compare': # 比较不同模型的预测结果 if not args.product_id: print("错误: 需要指定product_id") return # 确定要比较的模型 if args.compare_models: model_types = args.compare_models.split(',') else: model_types = ['mlstm', 'transformer', 'kan'] # 执行比较 comparison = model_manager.compare_models( product_id=args.product_id, model_types=model_types ) if comparison is None: print("比较失败。") elif args.action == 'delete': # 删除模型 if not args.product_id or not args.model_type: print("错误: 需要指定product_id和model_type") return success = model_manager.delete_model( product_id=args.product_id, model_type=args.model_type, version=args.version ) if success: print("模型删除成功。") else: print("模型删除失败。") elif args.action == 'export': # 导出模型 if not args.product_id or not args.model_type: print("错误: 需要指定product_id和model_type") return export_path = model_manager.export_model( product_id=args.product_id, model_type=args.model_type, version=args.version, export_dir=args.export_dir ) if export_path: print(f"模型已成功导出到: {export_path}") else: print("模型导出失败。") elif args.action == 'import': # 导入模型 if not args.file_path: print("错误: 需要指定file_path") return import_path = model_manager.import_model( import_file=args.file_path, overwrite=True ) if import_path: print(f"模型已成功导入到: {import_path}") else: print("模型导入失败。") def interactive_mode(): """交互模式,通过菜单与用户交互""" model_manager = ModelManager() while True: print("\n" + "="*50) print("📊 药店销售预测系统 - 模型管理工具 📊") print("="*50) print("1. 查看所有模型") print("2. 查看特定产品的模型") print("3. 查看特定模型的详细信息") print("4. 使用模型进行预测") print("5. 比较不同模型的预测结果") print("6. 删除模型") print("7. 导出模型") print("8. 导入模型") print("0. 退出") print("="*50) choice = input("请输入选项 (0-8): ") if choice == '0': print("感谢使用模型管理工具!再见!") break elif choice == '1': # 查看所有模型 models = model_manager.list_models() if not models: print("没有找到任何模型。") continue print(f"\n找到 {len(models)} 个模型:") for i, model in enumerate(models, 1): print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})") print(f" 创建时间: {model['created_at']}") print(f" 文件大小: {model['file_size']}") elif choice == '2': # 查看特定产品的模型 product_id = input("请输入产品ID (例如P001): ") models = model_manager.list_models(product_id=product_id) if not models: print(f"没有找到产品 {product_id} 的模型。") continue print(f"\n找到 {len(models)} 个 {product_id} 的模型:") for i, model in enumerate(models, 1): print(f"\n{i}. {model['model_type'].upper()} (版本: {model['version']})") print(f" 创建时间: {model['created_at']}") print(f" 文件大小: {model['file_size']}") elif choice == '3': # 查看模型详情 product_id = input("请输入产品ID (例如P001): ") model_type = input("请输入模型类型 (mlstm/transformer/kan): ") version = input("请输入版本 (如不指定则使用最新版本): ") or None details = model_manager.get_model_details( product_id=product_id, model_type=model_type, version=version ) if not details: print("未找到模型详情。") continue print(f"\n{product_id} - {model_type.upper()} 模型详情:") for key, value in details.items(): if key == 'metrics': print(f"\n评估指标:") for metric, metric_value in value.items(): print(f" {metric}: {metric_value}") else: print(f"{key}: {value}") elif choice == '4': # 使用模型预测 product_id = input("请输入产品ID (例如P001): ") model_type = input("请输入模型类型 (mlstm/transformer/kan): ") or 'mlstm' version = input("请输入版本 (如不指定则使用最新版本): ") or None # 执行预测 predictions = model_manager.predict_with_model( product_id=product_id, model_type=model_type, version=version ) if predictions is None: print("预测失败。") elif choice == '5': # 比较模型 product_id = input("请输入产品ID (例如P001): ") compare_models = input("请输入要比较的模型类型,用逗号分隔 (如不指定则比较所有模型): ") or None # 确定要比较的模型 if compare_models: model_types = compare_models.split(',') else: model_types = ['mlstm', 'transformer', 'kan'] # 执行比较 comparison = model_manager.compare_models( product_id=product_id, model_types=model_types ) if comparison is None: print("比较失败。") elif choice == '6': # 删除模型 product_id = input("请输入产品ID (例如P001): ") model_type = input("请输入模型类型 (mlstm/transformer/kan): ") version = input("请输入版本 (如不指定则删除所有版本): ") or None confirm = input(f"确定要删除 {product_id} 的 {model_type} 模型吗?(y/n): ") if confirm.lower() != 'y': print("已取消删除操作。") continue success = model_manager.delete_model( product_id=product_id, model_type=model_type, version=version ) if success: print("模型删除成功。") else: print("模型删除失败。") elif choice == '7': # 导出模型 product_id = input("请输入产品ID (例如P001): ") model_type = input("请输入模型类型 (mlstm/transformer/kan): ") version = input("请输入版本 (如不指定则使用最新版本): ") or None export_dir = input("请输入导出目录 (默认为exported_models): ") or 'exported_models' export_path = model_manager.export_model( product_id=product_id, model_type=model_type, version=version, export_dir=export_dir ) if export_path: print(f"模型已成功导出到: {export_path}") else: print("模型导出失败。") elif choice == '8': # 导入模型 file_path = input("请输入要导入的模型文件路径: ") if not os.path.exists(file_path): print(f"错误: 文件 {file_path} 不存在") continue import_path = model_manager.import_model( import_file=file_path, overwrite=True ) if import_path: print(f"模型已成功导入到: {import_path}") else: print("模型导入失败。") else: print("无效的选项,请重新输入!") input("\n按回车键继续...") if __name__ == "__main__": # 检查命令行参数,如果没有参数则启动交互模式 import sys if len(sys.argv) == 1: interactive_mode() else: main()