#!/usr/bin/env python """ 检查模型文件命名和管理 """ import os import sys sys.path.append('server') def check_model_files(): """检查保存的模型文件""" print("=== 检查模型文件 ===") model_dir = 'saved_models' if not os.path.exists(model_dir): print("WARN: 模型目录不存在") return print(f"模型目录: {model_dir}") # 列出所有文件 files = [] for root, dirs, filenames in os.walk(model_dir): for filename in filenames: if filename.endswith(('.pth', '.pt')): full_path = os.path.join(root, filename) rel_path = os.path.relpath(full_path, model_dir) files.append((full_path, rel_path)) if not files: print("INFO: 没有找到模型文件") return print(f"\n找到 {len(files)} 个模型文件:") for full_path, rel_path in files: print(f"\n文件: {rel_path}") # 检查文件大小 size = os.path.getsize(full_path) print(f" 大小: {size:,} bytes") # 尝试解析文件名 try: from utils.model_manager import model_manager info = model_manager.parse_model_filename(os.path.basename(full_path)) if info: print(f" 解析结果:") print(f" 模型类型: {info.get('model_type', 'N/A')}") print(f" 产品ID: {info.get('product_id', 'N/A')}") print(f" 训练模式: {info.get('training_mode', 'N/A')}") print(f" 版本: {info.get('version', 'N/A')}") if info.get('store_id'): print(f" 店铺ID: {info.get('store_id')}") if info.get('aggregation_method'): print(f" 聚合方法: {info.get('aggregation_method')}") else: print(" 解析: 无法解析文件名") except Exception as e: print(f" 解析错误: {e}") # 尝试读取模型内容 try: import torch data = torch.load(full_path, map_location='cpu') if isinstance(data, dict): print(f" 内容包含:") keys = list(data.keys()) print(f" 主要键: {keys[:5]}") if 'model_manager_info' in data: info = data['model_manager_info'] print(f" 管理信息: 包含 ({len(info)} 项)") print(f" 产品名称: {info.get('product_name', 'N/A')}") print(f" 创建时间: {info.get('created_at', 'N/A')}") if 'metrics' in data: metrics = data['metrics'] print(f" 评估指标: {len(metrics)} 项") if 'rmse' in metrics: print(f" RMSE: {metrics['rmse']:.4f}") except Exception as e: print(f" 读取错误: {e}") def check_unified_api(): """检查统一API是否正常工作""" print("\n=== 检查统一API ===") try: from utils.model_manager import model_manager models = model_manager.list_models() print(f"统一管理器找到 {len(models)} 个模型") if models: print("\n模型详情:") for i, model in enumerate(models): print(f"{i+1}. {model.get('filename', 'N/A')}") print(f" 产品: {model.get('product_name', model.get('product_id', 'N/A'))}") print(f" 类型: {model.get('model_type', 'N/A')}") print(f" 模式: {model.get('training_mode', 'N/A')}") print(f" 文件大小: {model.get('file_size', 0):,} bytes") return True except Exception as e: print(f"统一管理器错误: {e}") return False def main(): """主函数""" print("检查模型文件和管理系统") check_model_files() check_unified_api() print("\n=== 总结 ===") print("✅ 模型文件命名应该使用统一格式") print("✅ 新训练的模型包含管理信息") print("✅ 模型管理器可以正确解析和列出模型") print("\n下一步:") print("1. 启动API服务器测试前端集成") print("2. 在前端查看模型是否正确显示") print("3. 测试不同训练模式的模型命名") if __name__ == "__main__": main()