ShopTRAINING/server/model_management.py

351 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import pandas as pd
import torch
import argparse
from models import ModelManager
def main():
"""模型管理工具主函数"""
parser = argparse.ArgumentParser(description='药店销售预测系统 - 模型管理工具')
parser.add_argument('--action', type=str, required=True,
choices=['list', 'details', 'predict', 'compare', 'delete', 'export', 'import'],
help='执行的操作类型')
parser.add_argument('--product_id', type=str, help='产品ID例如P001')
parser.add_argument('--model_type', type=str, help='模型类型 (mlstm/transformer/kan)')
parser.add_argument('--version', type=str, help='模型版本')
parser.add_argument('--file_path', type=str, help='导入/导出的文件路径')
parser.add_argument('--export_dir', type=str, default='exported_models', help='导出目录')
parser.add_argument('--compare_models', type=str, help='要比较的模型类型,用逗号分隔')
args = parser.parse_args()
# 初始化模型管理器
model_manager = ModelManager()
# 根据操作类型执行相应的功能
if args.action == 'list':
# 列出模型
models = model_manager.list_models(
product_id=args.product_id,
model_type=args.model_type
)
if not models:
print("没有找到匹配的模型。")
return
print(f"\n找到 {len(models)} 个模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
# 如果有评估指标,显示它们
if 'metrics' in model:
print(" 评估指标:")
for metric, value in model['metrics'].items():
print(f" {metric}: {value}")
elif args.action == 'details':
# 显示模型详情
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
details = model_manager.get_model_details(
product_id=args.product_id,
model_type=args.model_type,
version=args.version
)
if not details:
print("未找到模型详情。")
return
print(f"\n{args.product_id} - {args.model_type.upper()} 模型详情:")
for key, value in details.items():
if key == 'metrics':
print(f"\n评估指标:")
for metric, metric_value in value.items():
print(f" {metric}: {metric_value}")
else:
print(f"{key}: {value}")
elif args.action == 'predict':
# 使用模型进行预测
if not args.product_id:
print("错误: 需要指定product_id")
return
# 使用指定的模型类型如果未指定则使用mlstm
model_type = args.model_type if args.model_type else 'mlstm'
# 执行预测
predictions = model_manager.predict_with_model(
product_id=args.product_id,
model_type=model_type,
version=args.version
)
if predictions is None:
print("预测失败。")
elif args.action == 'compare':
# 比较不同模型的预测结果
if not args.product_id:
print("错误: 需要指定product_id")
return
# 确定要比较的模型
if args.compare_models:
model_types = args.compare_models.split(',')
else:
model_types = ['mlstm', 'transformer', 'kan']
# 执行比较
comparison = model_manager.compare_models(
product_id=args.product_id,
model_types=model_types
)
if comparison is None:
print("比较失败。")
elif args.action == 'delete':
# 删除模型
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
success = model_manager.delete_model(
product_id=args.product_id,
model_type=args.model_type,
version=args.version
)
if success:
print("模型删除成功。")
else:
print("模型删除失败。")
elif args.action == 'export':
# 导出模型
if not args.product_id or not args.model_type:
print("错误: 需要指定product_id和model_type")
return
export_path = model_manager.export_model(
product_id=args.product_id,
model_type=args.model_type,
version=args.version,
export_dir=args.export_dir
)
if export_path:
print(f"模型已成功导出到: {export_path}")
else:
print("模型导出失败。")
elif args.action == 'import':
# 导入模型
if not args.file_path:
print("错误: 需要指定file_path")
return
import_path = model_manager.import_model(
import_file=args.file_path,
overwrite=True
)
if import_path:
print(f"模型已成功导入到: {import_path}")
else:
print("模型导入失败。")
def interactive_mode():
"""交互模式,通过菜单与用户交互"""
model_manager = ModelManager()
while True:
print("\n" + "="*50)
print("📊 药店销售预测系统 - 模型管理工具 📊")
print("="*50)
print("1. 查看所有模型")
print("2. 查看特定产品的模型")
print("3. 查看特定模型的详细信息")
print("4. 使用模型进行预测")
print("5. 比较不同模型的预测结果")
print("6. 删除模型")
print("7. 导出模型")
print("8. 导入模型")
print("0. 退出")
print("="*50)
choice = input("请输入选项 (0-8): ")
if choice == '0':
print("感谢使用模型管理工具!再见!")
break
elif choice == '1':
# 查看所有模型
models = model_manager.list_models()
if not models:
print("没有找到任何模型。")
continue
print(f"\n找到 {len(models)} 个模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['product_id']} - {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
elif choice == '2':
# 查看特定产品的模型
product_id = input("请输入产品ID (例如P001): ")
models = model_manager.list_models(product_id=product_id)
if not models:
print(f"没有找到产品 {product_id} 的模型。")
continue
print(f"\n找到 {len(models)}{product_id} 的模型:")
for i, model in enumerate(models, 1):
print(f"\n{i}. {model['model_type'].upper()} (版本: {model['version']})")
print(f" 创建时间: {model['created_at']}")
print(f" 文件大小: {model['file_size']}")
elif choice == '3':
# 查看模型详情
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则使用最新版本): ") or None
details = model_manager.get_model_details(
product_id=product_id,
model_type=model_type,
version=version
)
if not details:
print("未找到模型详情。")
continue
print(f"\n{product_id} - {model_type.upper()} 模型详情:")
for key, value in details.items():
if key == 'metrics':
print(f"\n评估指标:")
for metric, metric_value in value.items():
print(f" {metric}: {metric_value}")
else:
print(f"{key}: {value}")
elif choice == '4':
# 使用模型预测
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ") or 'mlstm'
version = input("请输入版本 (如不指定则使用最新版本): ") or None
# 执行预测
predictions = model_manager.predict_with_model(
product_id=product_id,
model_type=model_type,
version=version
)
if predictions is None:
print("预测失败。")
elif choice == '5':
# 比较模型
product_id = input("请输入产品ID (例如P001): ")
compare_models = input("请输入要比较的模型类型,用逗号分隔 (如不指定则比较所有模型): ") or None
# 确定要比较的模型
if compare_models:
model_types = compare_models.split(',')
else:
model_types = ['mlstm', 'transformer', 'kan']
# 执行比较
comparison = model_manager.compare_models(
product_id=product_id,
model_types=model_types
)
if comparison is None:
print("比较失败。")
elif choice == '6':
# 删除模型
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则删除所有版本): ") or None
confirm = input(f"确定要删除 {product_id}{model_type} 模型吗?(y/n): ")
if confirm.lower() != 'y':
print("已取消删除操作。")
continue
success = model_manager.delete_model(
product_id=product_id,
model_type=model_type,
version=version
)
if success:
print("模型删除成功。")
else:
print("模型删除失败。")
elif choice == '7':
# 导出模型
product_id = input("请输入产品ID (例如P001): ")
model_type = input("请输入模型类型 (mlstm/transformer/kan): ")
version = input("请输入版本 (如不指定则使用最新版本): ") or None
export_dir = input("请输入导出目录 (默认为exported_models): ") or 'exported_models'
export_path = model_manager.export_model(
product_id=product_id,
model_type=model_type,
version=version,
export_dir=export_dir
)
if export_path:
print(f"模型已成功导出到: {export_path}")
else:
print("模型导出失败。")
elif choice == '8':
# 导入模型
file_path = input("请输入要导入的模型文件路径: ")
if not os.path.exists(file_path):
print(f"错误: 文件 {file_path} 不存在")
continue
import_path = model_manager.import_model(
import_file=file_path,
overwrite=True
)
if import_path:
print(f"模型已成功导入到: {import_path}")
else:
print("模型导入失败。")
else:
print("无效的选项,请重新输入!")
input("\n按回车键继续...")
if __name__ == "__main__":
# 检查命令行参数,如果没有参数则启动交互模式
import sys
if len(sys.argv) == 1:
interactive_mode()
else:
main()