ShopTRAINING/server/model_management.py

351 lines
13 KiB
Python
Raw Normal View History

import os
import pandas as pd
import torch
import argparse
from models import ModelManager
def main():
"""模型管理工具主函数"""
parser = argparse.ArgumentParser(description='药店销售预测系统 - 模型管理工具')
parser.add_argument('--action', type=str, required=True,
choices=['list', 'details', 'predict', 'compare', 'delete', 'export', 'import'],
help='执行的操作类型')
parser.add_argument('--product_id', type=str, help='产品ID例如P001')
parser.add_argument('--model_type', type=str, help='模型类型 (mlstm/transformer/kan)')
parser.add_argument('--version', type=str, help='模型版本')
parser.add_argument('--file_path', type=str, help='导入/导出的文件路径')
parser.add_argument('--export_dir', type=str, default='exported_models', help='导出目录')
parser.add_argument('--compare_models', type=str, help='要比较的模型类型,用逗号分隔')
args = parser.parse_args()
# 初始化模型管理器
model_manager = ModelManager()
# 根据操作类型执行相应的功能
if args.action == 'list':
# 列出模型
models = model_manager.list_models(
product_id=args.product_id,
model_type=args.model_type
)
if not models:
print("没有找到匹配的模型。")
return
print(f"\n找到 {len(models)} 个模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
# 如果有评估指标,显示它们
if 'metrics' in model:
print(" 评估指标:")
for metric, value in model['metrics'].items():
print(f" {metric}: {value}")
elif args.action == 'details':
# 显示模型详情
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
details = model_manager.get_model_details(
product_id=args.product_id,
model_type=args.model_type,
version=args.version
)
if not details:
print("未找到模型详情。")
return
print(f"\n{args.product_id} - {args.model_type.upper()} 模型详情:")
for key, value in details.items():
if key == 'metrics':
print(f"\n评估指标:")
for metric, metric_value in value.items():
print(f" {metric}: {metric_value}")
else:
print(f"{key}: {value}")
elif args.action == 'predict':
# 使用模型进行预测
if not args.product_id:
print("错误: 需要指定product_id")
return
# 使用指定的模型类型如果未指定则使用mlstm
model_type = args.model_type if args.model_type else 'mlstm'
# 执行预测
predictions = model_manager.predict_with_model(
product_id=args.product_id,
model_type=model_type,
version=args.version
)
if predictions is None:
print("预测失败。")
elif args.action == 'compare':
# 比较不同模型的预测结果
if not args.product_id:
print("错误: 需要指定product_id")
return
# 确定要比较的模型
if args.compare_models:
model_types = args.compare_models.split(',')
else:
model_types = ['mlstm', 'transformer', 'kan']
# 执行比较
comparison = model_manager.compare_models(
product_id=args.product_id,
model_types=model_types
)
if comparison is None:
print("比较失败。")
elif args.action == 'delete':
# 删除模型
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
success = model_manager.delete_model(
product_id=args.product_id,
model_type=args.model_type,
version=args.version
)
if success:
print("模型删除成功。")
else:
print("模型删除失败。")
elif args.action == 'export':
# 导出模型
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
export_path = model_manager.export_model(
product_id=args.product_id,
model_type=args.model_type,
version=args.version,
export_dir=args.export_dir
)
if export_path:
print(f"模型已成功导出到: {export_path}")
else:
print("模型导出失败。")
elif args.action == 'import':
# 导入模型
if not args.file_path:
print("错误: 需要指定file_path")
return
import_path = model_manager.import_model(
import_file=args.file_path,
overwrite=True
)
if import_path:
print(f"模型已成功导入到: {import_path}")
else:
print("模型导入失败。")
def interactive_mode():
"""交互模式,通过菜单与用户交互"""
model_manager = ModelManager()
while True:
print("\n" + "="*50)
print("📊 药店销售预测系统 - 模型管理工具 📊")
print("="*50)
print("1. 查看所有模型")
print("2. 查看特定产品的模型")
print("3. 查看特定模型的详细信息")
print("4. 使用模型进行预测")
print("5. 比较不同模型的预测结果")
print("6. 删除模型")
print("7. 导出模型")
print("8. 导入模型")
print("0. 退出")
print("="*50)
choice = input("请输入选项 (0-8): ")
if choice == '0':
print("感谢使用模型管理工具!再见!")
break
elif choice == '1':
# 查看所有模型
models = model_manager.list_models()
if not models:
print("没有找到任何模型。")
continue
print(f"\n找到 {len(models)} 个模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
elif choice == '2':
# 查看特定产品的模型
product_id = input("请输入产品ID (例如P001): ")
models = model_manager.list_models(product_id=product_id)
if not models:
print(f"没有找到产品 {product_id} 的模型。")
continue
print(f"\n找到 {len(models)}{product_id} 的模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
elif choice == '3':
# 查看模型详情
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则使用最新版本): ") or None
details = model_manager.get_model_details(
product_id=product_id,
model_type=model_type,
version=version
)
if not details:
print("未找到模型详情。")
continue
print(f"\n{product_id} - {model_type.upper()} 模型详情:")
for key, value in details.items():
if key == 'metrics':
print(f"\n评估指标:")
for metric, metric_value in value.items():
print(f" {metric}: {metric_value}")
else:
print(f"{key}: {value}")
elif choice == '4':
# 使用模型预测
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ") or 'mlstm'
version = input("请输入版本 (如不指定则使用最新版本): ") or None
# 执行预测
predictions = model_manager.predict_with_model(
product_id=product_id,
model_type=model_type,
version=version
)
if predictions is None:
print("预测失败。")
elif choice == '5':
# 比较模型
product_id = input("请输入产品ID (例如P001): ")
compare_models = input("请输入要比较的模型类型,用逗号分隔 (如不指定则比较所有模型): ") or None
# 确定要比较的模型
if compare_models:
model_types = compare_models.split(',')
else:
model_types = ['mlstm', 'transformer', 'kan']
# 执行比较
comparison = model_manager.compare_models(
product_id=product_id,
model_types=model_types
)
if comparison is None:
print("比较失败。")
elif choice == '6':
# 删除模型
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则删除所有版本): ") or None
confirm = input(f"确定要删除 {product_id}{model_type} 模型吗?(y/n): ")
if confirm.lower() != 'y':
print("已取消删除操作。")
continue
success = model_manager.delete_model(
product_id=product_id,
model_type=model_type,
version=version
)
if success:
print("模型删除成功。")
else:
print("模型删除失败。")
elif choice == '7':
# 导出模型
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则使用最新版本): ") or None
export_dir = input("请输入导出目录 (默认为exported_models): ") or 'exported_models'
export_path = model_manager.export_model(
product_id=product_id,
model_type=model_type,
version=version,
export_dir=export_dir
)
if export_path:
print(f"模型已成功导出到: {export_path}")
else:
print("模型导出失败。")
elif choice == '8':
# 导入模型
file_path = input("请输入要导入的模型文件路径: ")
if not os.path.exists(file_path):
print(f"错误: 文件 {file_path} 不存在")
continue
import_path = model_manager.import_model(
import_file=file_path,
overwrite=True
)
if import_path:
print(f"模型已成功导入到: {import_path}")
else:
print("模型导入失败。")
else:
print("无效的选项,请重新输入!")
input("\n按回车键继续...")
if __name__ == "__main__":
# 检查命令行参数,如果没有参数则启动交互模式
import sys
if len(sys.argv) == 1:
interactive_mode()
else:
main()