351 lines
13 KiB
Python
351 lines
13 KiB
Python
import os
|
||
import pandas as pd
|
||
import torch
|
||
import argparse
|
||
from models import ModelManager
|
||
|
||
def main():
|
||
"""模型管理工具主函数"""
|
||
parser = argparse.ArgumentParser(description='药店销售预测系统 - 模型管理工具')
|
||
parser.add_argument('--action', type=str, required=True,
|
||
choices=['list', 'details', 'predict', 'compare', 'delete', 'export', 'import'],
|
||
help='执行的操作类型')
|
||
parser.add_argument('--product_id', type=str, help='产品ID,例如P001')
|
||
parser.add_argument('--model_type', type=str, help='模型类型 (mlstm/transformer/kan)')
|
||
parser.add_argument('--version', type=str, help='模型版本')
|
||
parser.add_argument('--file_path', type=str, help='导入/导出的文件路径')
|
||
parser.add_argument('--export_dir', type=str, default='exported_models', help='导出目录')
|
||
parser.add_argument('--compare_models', type=str, help='要比较的模型类型,用逗号分隔')
|
||
|
||
args = parser.parse_args()
|
||
|
||
# 初始化模型管理器
|
||
model_manager = ModelManager()
|
||
|
||
# 根据操作类型执行相应的功能
|
||
if args.action == 'list':
|
||
# 列出模型
|
||
models = model_manager.list_models(
|
||
product_id=args.product_id,
|
||
model_type=args.model_type
|
||
)
|
||
|
||
if not models:
|
||
print("没有找到匹配的模型。")
|
||
return
|
||
|
||
print(f"\n找到 {len(models)} 个模型:")
|
||
for i, model in enumerate(models, 1):
|
||
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
|
||
print(f" 创建时间: {model['created_at']}")
|
||
print(f" 文件大小: {model['file_size']}")
|
||
|
||
# 如果有评估指标,显示它们
|
||
if 'metrics' in model:
|
||
print(" 评估指标:")
|
||
for metric, value in model['metrics'].items():
|
||
print(f" {metric}: {value}")
|
||
|
||
elif args.action == 'details':
|
||
# 显示模型详情
|
||
if not args.product_id or not args.model_type:
|
||
print("错误: 需要指定product_id和model_type")
|
||
return
|
||
|
||
details = model_manager.get_model_details(
|
||
product_id=args.product_id,
|
||
model_type=args.model_type,
|
||
version=args.version
|
||
)
|
||
|
||
if not details:
|
||
print("未找到模型详情。")
|
||
return
|
||
|
||
print(f"\n{args.product_id} - {args.model_type.upper()} 模型详情:")
|
||
for key, value in details.items():
|
||
if key == 'metrics':
|
||
print(f"\n评估指标:")
|
||
for metric, metric_value in value.items():
|
||
print(f" {metric}: {metric_value}")
|
||
else:
|
||
print(f"{key}: {value}")
|
||
|
||
elif args.action == 'predict':
|
||
# 使用模型进行预测
|
||
if not args.product_id:
|
||
print("错误: 需要指定product_id")
|
||
return
|
||
|
||
# 使用指定的模型类型,如果未指定则使用mlstm
|
||
model_type = args.model_type if args.model_type else 'mlstm'
|
||
|
||
# 执行预测
|
||
predictions = model_manager.predict_with_model(
|
||
product_id=args.product_id,
|
||
model_type=model_type,
|
||
version=args.version
|
||
)
|
||
|
||
if predictions is None:
|
||
print("预测失败。")
|
||
|
||
elif args.action == 'compare':
|
||
# 比较不同模型的预测结果
|
||
if not args.product_id:
|
||
print("错误: 需要指定product_id")
|
||
return
|
||
|
||
# 确定要比较的模型
|
||
if args.compare_models:
|
||
model_types = args.compare_models.split(',')
|
||
else:
|
||
model_types = ['mlstm', 'transformer', 'kan']
|
||
|
||
# 执行比较
|
||
comparison = model_manager.compare_models(
|
||
product_id=args.product_id,
|
||
model_types=model_types
|
||
)
|
||
|
||
if comparison is None:
|
||
print("比较失败。")
|
||
|
||
elif args.action == 'delete':
|
||
# 删除模型
|
||
if not args.product_id or not args.model_type:
|
||
print("错误: 需要指定product_id和model_type")
|
||
return
|
||
|
||
success = model_manager.delete_model(
|
||
product_id=args.product_id,
|
||
model_type=args.model_type,
|
||
version=args.version
|
||
)
|
||
|
||
if success:
|
||
print("模型删除成功。")
|
||
else:
|
||
print("模型删除失败。")
|
||
|
||
elif args.action == 'export':
|
||
# 导出模型
|
||
if not args.product_id or not args.model_type:
|
||
print("错误: 需要指定product_id和model_type")
|
||
return
|
||
|
||
export_path = model_manager.export_model(
|
||
product_id=args.product_id,
|
||
model_type=args.model_type,
|
||
version=args.version,
|
||
export_dir=args.export_dir
|
||
)
|
||
|
||
if export_path:
|
||
print(f"模型已成功导出到: {export_path}")
|
||
else:
|
||
print("模型导出失败。")
|
||
|
||
elif args.action == 'import':
|
||
# 导入模型
|
||
if not args.file_path:
|
||
print("错误: 需要指定file_path")
|
||
return
|
||
|
||
import_path = model_manager.import_model(
|
||
import_file=args.file_path,
|
||
overwrite=True
|
||
)
|
||
|
||
if import_path:
|
||
print(f"模型已成功导入到: {import_path}")
|
||
else:
|
||
print("模型导入失败。")
|
||
|
||
def interactive_mode():
|
||
"""交互模式,通过菜单与用户交互"""
|
||
model_manager = ModelManager()
|
||
|
||
while True:
|
||
print("\n" + "="*50)
|
||
print("📊 药店销售预测系统 - 模型管理工具 📊")
|
||
print("="*50)
|
||
print("1. 查看所有模型")
|
||
print("2. 查看特定产品的模型")
|
||
print("3. 查看特定模型的详细信息")
|
||
print("4. 使用模型进行预测")
|
||
print("5. 比较不同模型的预测结果")
|
||
print("6. 删除模型")
|
||
print("7. 导出模型")
|
||
print("8. 导入模型")
|
||
print("0. 退出")
|
||
print("="*50)
|
||
|
||
choice = input("请输入选项 (0-8): ")
|
||
|
||
if choice == '0':
|
||
print("感谢使用模型管理工具!再见!")
|
||
break
|
||
|
||
elif choice == '1':
|
||
# 查看所有模型
|
||
models = model_manager.list_models()
|
||
|
||
if not models:
|
||
print("没有找到任何模型。")
|
||
continue
|
||
|
||
print(f"\n找到 {len(models)} 个模型:")
|
||
for i, model in enumerate(models, 1):
|
||
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
|
||
print(f" 创建时间: {model['created_at']}")
|
||
print(f" 文件大小: {model['file_size']}")
|
||
|
||
elif choice == '2':
|
||
# 查看特定产品的模型
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
|
||
models = model_manager.list_models(product_id=product_id)
|
||
|
||
if not models:
|
||
print(f"没有找到产品 {product_id} 的模型。")
|
||
continue
|
||
|
||
print(f"\n找到 {len(models)} 个 {product_id} 的模型:")
|
||
for i, model in enumerate(models, 1):
|
||
print(f"\n{i}. {model['model_type'].upper()} (版本: {model['version']})")
|
||
print(f" 创建时间: {model['created_at']}")
|
||
print(f" 文件大小: {model['file_size']}")
|
||
|
||
elif choice == '3':
|
||
# 查看模型详情
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
||
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
||
|
||
details = model_manager.get_model_details(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
version=version
|
||
)
|
||
|
||
if not details:
|
||
print("未找到模型详情。")
|
||
continue
|
||
|
||
print(f"\n{product_id} - {model_type.upper()} 模型详情:")
|
||
for key, value in details.items():
|
||
if key == 'metrics':
|
||
print(f"\n评估指标:")
|
||
for metric, metric_value in value.items():
|
||
print(f" {metric}: {metric_value}")
|
||
else:
|
||
print(f"{key}: {value}")
|
||
|
||
elif choice == '4':
|
||
# 使用模型预测
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
model_type = input("请输入模型类型 (mlstm/transformer/kan): ") or 'mlstm'
|
||
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
||
|
||
# 执行预测
|
||
predictions = model_manager.predict_with_model(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
version=version
|
||
)
|
||
|
||
if predictions is None:
|
||
print("预测失败。")
|
||
|
||
elif choice == '5':
|
||
# 比较模型
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
compare_models = input("请输入要比较的模型类型,用逗号分隔 (如不指定则比较所有模型): ") or None
|
||
|
||
# 确定要比较的模型
|
||
if compare_models:
|
||
model_types = compare_models.split(',')
|
||
else:
|
||
model_types = ['mlstm', 'transformer', 'kan']
|
||
|
||
# 执行比较
|
||
comparison = model_manager.compare_models(
|
||
product_id=product_id,
|
||
model_types=model_types
|
||
)
|
||
|
||
if comparison is None:
|
||
print("比较失败。")
|
||
|
||
elif choice == '6':
|
||
# 删除模型
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
||
version = input("请输入版本 (如不指定则删除所有版本): ") or None
|
||
|
||
confirm = input(f"确定要删除 {product_id} 的 {model_type} 模型吗?(y/n): ")
|
||
if confirm.lower() != 'y':
|
||
print("已取消删除操作。")
|
||
continue
|
||
|
||
success = model_manager.delete_model(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
version=version
|
||
)
|
||
|
||
if success:
|
||
print("模型删除成功。")
|
||
else:
|
||
print("模型删除失败。")
|
||
|
||
elif choice == '7':
|
||
# 导出模型
|
||
product_id = input("请输入产品ID (例如P001): ")
|
||
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
||
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
||
export_dir = input("请输入导出目录 (默认为exported_models): ") or 'exported_models'
|
||
|
||
export_path = model_manager.export_model(
|
||
product_id=product_id,
|
||
model_type=model_type,
|
||
version=version,
|
||
export_dir=export_dir
|
||
)
|
||
|
||
if export_path:
|
||
print(f"模型已成功导出到: {export_path}")
|
||
else:
|
||
print("模型导出失败。")
|
||
|
||
elif choice == '8':
|
||
# 导入模型
|
||
file_path = input("请输入要导入的模型文件路径: ")
|
||
|
||
if not os.path.exists(file_path):
|
||
print(f"错误: 文件 {file_path} 不存在")
|
||
continue
|
||
|
||
import_path = model_manager.import_model(
|
||
import_file=file_path,
|
||
overwrite=True
|
||
)
|
||
|
||
if import_path:
|
||
print(f"模型已成功导入到: {import_path}")
|
||
else:
|
||
print("模型导入失败。")
|
||
|
||
else:
|
||
print("无效的选项,请重新输入!")
|
||
|
||
input("\n按回车键继续...")
|
||
|
||
if __name__ == "__main__":
|
||
# 检查命令行参数,如果没有参数则启动交互模式
|
||
import sys
|
||
if len(sys.argv) == 1:
|
||
interactive_mode()
|
||
else:
|
||
main() |