351 lines
13 KiB
Python
351 lines
13 KiB
Python
![]() |
import os
|
|||
|
import pandas as pd
|
|||
|
import torch
|
|||
|
import argparse
|
|||
|
from models import ModelManager
|
|||
|
|
|||
|
def main():
|
|||
|
"""模型管理工具主函数"""
|
|||
|
parser = argparse.ArgumentParser(description='药店销售预测系统 - 模型管理工具')
|
|||
|
parser.add_argument('--action', type=str, required=True,
|
|||
|
choices=['list', 'details', 'predict', 'compare', 'delete', 'export', 'import'],
|
|||
|
help='执行的操作类型')
|
|||
|
parser.add_argument('--product_id', type=str, help='产品ID,例如P001')
|
|||
|
parser.add_argument('--model_type', type=str, help='模型类型 (mlstm/transformer/kan)')
|
|||
|
parser.add_argument('--version', type=str, help='模型版本')
|
|||
|
parser.add_argument('--file_path', type=str, help='导入/导出的文件路径')
|
|||
|
parser.add_argument('--export_dir', type=str, default='exported_models', help='导出目录')
|
|||
|
parser.add_argument('--compare_models', type=str, help='要比较的模型类型,用逗号分隔')
|
|||
|
|
|||
|
args = parser.parse_args()
|
|||
|
|
|||
|
# 初始化模型管理器
|
|||
|
model_manager = ModelManager()
|
|||
|
|
|||
|
# 根据操作类型执行相应的功能
|
|||
|
if args.action == 'list':
|
|||
|
# 列出模型
|
|||
|
models = model_manager.list_models(
|
|||
|
product_id=args.product_id,
|
|||
|
model_type=args.model_type
|
|||
|
)
|
|||
|
|
|||
|
if not models:
|
|||
|
print("没有找到匹配的模型。")
|
|||
|
return
|
|||
|
|
|||
|
print(f"\n找到 {len(models)} 个模型:")
|
|||
|
for i, model in enumerate(models, 1):
|
|||
|
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
|
|||
|
print(f" 创建时间: {model['created_at']}")
|
|||
|
print(f" 文件大小: {model['file_size']}")
|
|||
|
|
|||
|
# 如果有评估指标,显示它们
|
|||
|
if 'metrics' in model:
|
|||
|
print(" 评估指标:")
|
|||
|
for metric, value in model['metrics'].items():
|
|||
|
print(f" {metric}: {value}")
|
|||
|
|
|||
|
elif args.action == 'details':
|
|||
|
# 显示模型详情
|
|||
|
if not args.product_id or not args.model_type:
|
|||
|
print("错误: 需要指定product_id和model_type")
|
|||
|
return
|
|||
|
|
|||
|
details = model_manager.get_model_details(
|
|||
|
product_id=args.product_id,
|
|||
|
model_type=args.model_type,
|
|||
|
version=args.version
|
|||
|
)
|
|||
|
|
|||
|
if not details:
|
|||
|
print("未找到模型详情。")
|
|||
|
return
|
|||
|
|
|||
|
print(f"\n{args.product_id} - {args.model_type.upper()} 模型详情:")
|
|||
|
for key, value in details.items():
|
|||
|
if key == 'metrics':
|
|||
|
print(f"\n评估指标:")
|
|||
|
for metric, metric_value in value.items():
|
|||
|
print(f" {metric}: {metric_value}")
|
|||
|
else:
|
|||
|
print(f"{key}: {value}")
|
|||
|
|
|||
|
elif args.action == 'predict':
|
|||
|
# 使用模型进行预测
|
|||
|
if not args.product_id:
|
|||
|
print("错误: 需要指定product_id")
|
|||
|
return
|
|||
|
|
|||
|
# 使用指定的模型类型,如果未指定则使用mlstm
|
|||
|
model_type = args.model_type if args.model_type else 'mlstm'
|
|||
|
|
|||
|
# 执行预测
|
|||
|
predictions = model_manager.predict_with_model(
|
|||
|
product_id=args.product_id,
|
|||
|
model_type=model_type,
|
|||
|
version=args.version
|
|||
|
)
|
|||
|
|
|||
|
if predictions is None:
|
|||
|
print("预测失败。")
|
|||
|
|
|||
|
elif args.action == 'compare':
|
|||
|
# 比较不同模型的预测结果
|
|||
|
if not args.product_id:
|
|||
|
print("错误: 需要指定product_id")
|
|||
|
return
|
|||
|
|
|||
|
# 确定要比较的模型
|
|||
|
if args.compare_models:
|
|||
|
model_types = args.compare_models.split(',')
|
|||
|
else:
|
|||
|
model_types = ['mlstm', 'transformer', 'kan']
|
|||
|
|
|||
|
# 执行比较
|
|||
|
comparison = model_manager.compare_models(
|
|||
|
product_id=args.product_id,
|
|||
|
model_types=model_types
|
|||
|
)
|
|||
|
|
|||
|
if comparison is None:
|
|||
|
print("比较失败。")
|
|||
|
|
|||
|
elif args.action == 'delete':
|
|||
|
# 删除模型
|
|||
|
if not args.product_id or not args.model_type:
|
|||
|
print("错误: 需要指定product_id和model_type")
|
|||
|
return
|
|||
|
|
|||
|
success = model_manager.delete_model(
|
|||
|
product_id=args.product_id,
|
|||
|
model_type=args.model_type,
|
|||
|
version=args.version
|
|||
|
)
|
|||
|
|
|||
|
if success:
|
|||
|
print("模型删除成功。")
|
|||
|
else:
|
|||
|
print("模型删除失败。")
|
|||
|
|
|||
|
elif args.action == 'export':
|
|||
|
# 导出模型
|
|||
|
if not args.product_id or not args.model_type:
|
|||
|
print("错误: 需要指定product_id和model_type")
|
|||
|
return
|
|||
|
|
|||
|
export_path = model_manager.export_model(
|
|||
|
product_id=args.product_id,
|
|||
|
model_type=args.model_type,
|
|||
|
version=args.version,
|
|||
|
export_dir=args.export_dir
|
|||
|
)
|
|||
|
|
|||
|
if export_path:
|
|||
|
print(f"模型已成功导出到: {export_path}")
|
|||
|
else:
|
|||
|
print("模型导出失败。")
|
|||
|
|
|||
|
elif args.action == 'import':
|
|||
|
# 导入模型
|
|||
|
if not args.file_path:
|
|||
|
print("错误: 需要指定file_path")
|
|||
|
return
|
|||
|
|
|||
|
import_path = model_manager.import_model(
|
|||
|
import_file=args.file_path,
|
|||
|
overwrite=True
|
|||
|
)
|
|||
|
|
|||
|
if import_path:
|
|||
|
print(f"模型已成功导入到: {import_path}")
|
|||
|
else:
|
|||
|
print("模型导入失败。")
|
|||
|
|
|||
|
def interactive_mode():
|
|||
|
"""交互模式,通过菜单与用户交互"""
|
|||
|
model_manager = ModelManager()
|
|||
|
|
|||
|
while True:
|
|||
|
print("\n" + "="*50)
|
|||
|
print("📊 药店销售预测系统 - 模型管理工具 📊")
|
|||
|
print("="*50)
|
|||
|
print("1. 查看所有模型")
|
|||
|
print("2. 查看特定产品的模型")
|
|||
|
print("3. 查看特定模型的详细信息")
|
|||
|
print("4. 使用模型进行预测")
|
|||
|
print("5. 比较不同模型的预测结果")
|
|||
|
print("6. 删除模型")
|
|||
|
print("7. 导出模型")
|
|||
|
print("8. 导入模型")
|
|||
|
print("0. 退出")
|
|||
|
print("="*50)
|
|||
|
|
|||
|
choice = input("请输入选项 (0-8): ")
|
|||
|
|
|||
|
if choice == '0':
|
|||
|
print("感谢使用模型管理工具!再见!")
|
|||
|
break
|
|||
|
|
|||
|
elif choice == '1':
|
|||
|
# 查看所有模型
|
|||
|
models = model_manager.list_models()
|
|||
|
|
|||
|
if not models:
|
|||
|
print("没有找到任何模型。")
|
|||
|
continue
|
|||
|
|
|||
|
print(f"\n找到 {len(models)} 个模型:")
|
|||
|
for i, model in enumerate(models, 1):
|
|||
|
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
|
|||
|
print(f" 创建时间: {model['created_at']}")
|
|||
|
print(f" 文件大小: {model['file_size']}")
|
|||
|
|
|||
|
elif choice == '2':
|
|||
|
# 查看特定产品的模型
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
|
|||
|
models = model_manager.list_models(product_id=product_id)
|
|||
|
|
|||
|
if not models:
|
|||
|
print(f"没有找到产品 {product_id} 的模型。")
|
|||
|
continue
|
|||
|
|
|||
|
print(f"\n找到 {len(models)} 个 {product_id} 的模型:")
|
|||
|
for i, model in enumerate(models, 1):
|
|||
|
print(f"\n{i}. {model['model_type'].upper()} (版本: {model['version']})")
|
|||
|
print(f" 创建时间: {model['created_at']}")
|
|||
|
print(f" 文件大小: {model['file_size']}")
|
|||
|
|
|||
|
elif choice == '3':
|
|||
|
# 查看模型详情
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
|||
|
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
|||
|
|
|||
|
details = model_manager.get_model_details(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
version=version
|
|||
|
)
|
|||
|
|
|||
|
if not details:
|
|||
|
print("未找到模型详情。")
|
|||
|
continue
|
|||
|
|
|||
|
print(f"\n{product_id} - {model_type.upper()} 模型详情:")
|
|||
|
for key, value in details.items():
|
|||
|
if key == 'metrics':
|
|||
|
print(f"\n评估指标:")
|
|||
|
for metric, metric_value in value.items():
|
|||
|
print(f" {metric}: {metric_value}")
|
|||
|
else:
|
|||
|
print(f"{key}: {value}")
|
|||
|
|
|||
|
elif choice == '4':
|
|||
|
# 使用模型预测
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
model_type = input("请输入模型类型 (mlstm/transformer/kan): ") or 'mlstm'
|
|||
|
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
|||
|
|
|||
|
# 执行预测
|
|||
|
predictions = model_manager.predict_with_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
version=version
|
|||
|
)
|
|||
|
|
|||
|
if predictions is None:
|
|||
|
print("预测失败。")
|
|||
|
|
|||
|
elif choice == '5':
|
|||
|
# 比较模型
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
compare_models = input("请输入要比较的模型类型,用逗号分隔 (如不指定则比较所有模型): ") or None
|
|||
|
|
|||
|
# 确定要比较的模型
|
|||
|
if compare_models:
|
|||
|
model_types = compare_models.split(',')
|
|||
|
else:
|
|||
|
model_types = ['mlstm', 'transformer', 'kan']
|
|||
|
|
|||
|
# 执行比较
|
|||
|
comparison = model_manager.compare_models(
|
|||
|
product_id=product_id,
|
|||
|
model_types=model_types
|
|||
|
)
|
|||
|
|
|||
|
if comparison is None:
|
|||
|
print("比较失败。")
|
|||
|
|
|||
|
elif choice == '6':
|
|||
|
# 删除模型
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
|||
|
version = input("请输入版本 (如不指定则删除所有版本): ") or None
|
|||
|
|
|||
|
confirm = input(f"确定要删除 {product_id} 的 {model_type} 模型吗?(y/n): ")
|
|||
|
if confirm.lower() != 'y':
|
|||
|
print("已取消删除操作。")
|
|||
|
continue
|
|||
|
|
|||
|
success = model_manager.delete_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
version=version
|
|||
|
)
|
|||
|
|
|||
|
if success:
|
|||
|
print("模型删除成功。")
|
|||
|
else:
|
|||
|
print("模型删除失败。")
|
|||
|
|
|||
|
elif choice == '7':
|
|||
|
# 导出模型
|
|||
|
product_id = input("请输入产品ID (例如P001): ")
|
|||
|
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
|
|||
|
version = input("请输入版本 (如不指定则使用最新版本): ") or None
|
|||
|
export_dir = input("请输入导出目录 (默认为exported_models): ") or 'exported_models'
|
|||
|
|
|||
|
export_path = model_manager.export_model(
|
|||
|
product_id=product_id,
|
|||
|
model_type=model_type,
|
|||
|
version=version,
|
|||
|
export_dir=export_dir
|
|||
|
)
|
|||
|
|
|||
|
if export_path:
|
|||
|
print(f"模型已成功导出到: {export_path}")
|
|||
|
else:
|
|||
|
print("模型导出失败。")
|
|||
|
|
|||
|
elif choice == '8':
|
|||
|
# 导入模型
|
|||
|
file_path = input("请输入要导入的模型文件路径: ")
|
|||
|
|
|||
|
if not os.path.exists(file_path):
|
|||
|
print(f"错误: 文件 {file_path} 不存在")
|
|||
|
continue
|
|||
|
|
|||
|
import_path = model_manager.import_model(
|
|||
|
import_file=file_path,
|
|||
|
overwrite=True
|
|||
|
)
|
|||
|
|
|||
|
if import_path:
|
|||
|
print(f"模型已成功导入到: {import_path}")
|
|||
|
else:
|
|||
|
print("模型导入失败。")
|
|||
|
|
|||
|
else:
|
|||
|
print("无效的选项,请重新输入!")
|
|||
|
|
|||
|
input("\n按回车键继续...")
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
# 检查命令行参数,如果没有参数则启动交互模式
|
|||
|
import sys
|
|||
|
if len(sys.argv) == 1:
|
|||
|
interactive_mode()
|
|||
|
else:
|
|||
|
main()
|