ShopTRAINING/test/check_model_files.py
2025-07-02 11:05:23 +08:00

126 lines
4.5 KiB
Python

#!/usr/bin/env python
"""
检查模型文件命名和管理
"""
import os
import sys
sys.path.append('server')
def check_model_files():
"""检查保存的模型文件"""
print("=== 检查模型文件 ===")
model_dir = 'saved_models'
if not os.path.exists(model_dir):
print("WARN: 模型目录不存在")
return
print(f"模型目录: {model_dir}")
# 列出所有文件
files = []
for root, dirs, filenames in os.walk(model_dir):
for filename in filenames:
if filename.endswith(('.pth', '.pt')):
full_path = os.path.join(root, filename)
rel_path = os.path.relpath(full_path, model_dir)
files.append((full_path, rel_path))
if not files:
print("INFO: 没有找到模型文件")
return
print(f"\n找到 {len(files)} 个模型文件:")
for full_path, rel_path in files:
print(f"\n文件: {rel_path}")
# 检查文件大小
size = os.path.getsize(full_path)
print(f" 大小: {size:,} bytes")
# 尝试解析文件名
try:
from utils.model_manager import model_manager
info = model_manager.parse_model_filename(os.path.basename(full_path))
if info:
print(f" 解析结果:")
print(f" 模型类型: {info.get('model_type', 'N/A')}")
print(f" 产品ID: {info.get('product_id', 'N/A')}")
print(f" 训练模式: {info.get('training_mode', 'N/A')}")
print(f" 版本: {info.get('version', 'N/A')}")
if info.get('store_id'):
print(f" 店铺ID: {info.get('store_id')}")
if info.get('aggregation_method'):
print(f" 聚合方法: {info.get('aggregation_method')}")
else:
print(" 解析: 无法解析文件名")
except Exception as e:
print(f" 解析错误: {e}")
# 尝试读取模型内容
try:
import torch
data = torch.load(full_path, map_location='cpu')
if isinstance(data, dict):
print(f" 内容包含:")
keys = list(data.keys())
print(f" 主要键: {keys[:5]}")
if 'model_manager_info' in data:
info = data['model_manager_info']
print(f" 管理信息: 包含 ({len(info)} 项)")
print(f" 产品名称: {info.get('product_name', 'N/A')}")
print(f" 创建时间: {info.get('created_at', 'N/A')}")
if 'metrics' in data:
metrics = data['metrics']
print(f" 评估指标: {len(metrics)}")
if 'rmse' in metrics:
print(f" RMSE: {metrics['rmse']:.4f}")
except Exception as e:
print(f" 读取错误: {e}")
def check_unified_api():
"""检查统一API是否正常工作"""
print("\n=== 检查统一API ===")
try:
from utils.model_manager import model_manager
models = model_manager.list_models()
print(f"统一管理器找到 {len(models)} 个模型")
if models:
print("\n模型详情:")
for i, model in enumerate(models):
print(f"{i+1}. {model.get('filename', 'N/A')}")
print(f" 产品: {model.get('product_name', model.get('product_id', 'N/A'))}")
print(f" 类型: {model.get('model_type', 'N/A')}")
print(f" 模式: {model.get('training_mode', 'N/A')}")
print(f" 文件大小: {model.get('file_size', 0):,} bytes")
return True
except Exception as e:
print(f"统一管理器错误: {e}")
return False
def main():
"""主函数"""
print("检查模型文件和管理系统")
check_model_files()
check_unified_api()
print("\n=== 总结 ===")
print("✅ 模型文件命名应该使用统一格式")
print("✅ 新训练的模型包含管理信息")
print("✅ 模型管理器可以正确解析和列出模型")
print("\n下一步:")
print("1. 启动API服务器测试前端集成")
print("2. 在前端查看模型是否正确显示")
print("3. 测试不同训练模式的模型命名")
if __name__ == "__main__":
main()