126 lines
4.5 KiB
Python
126 lines
4.5 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
检查模型文件命名和管理
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
sys.path.append('server')
|
|
|
|
def check_model_files():
|
|
"""检查保存的模型文件"""
|
|
print("=== 检查模型文件 ===")
|
|
|
|
model_dir = 'saved_models'
|
|
if not os.path.exists(model_dir):
|
|
print("WARN: 模型目录不存在")
|
|
return
|
|
|
|
print(f"模型目录: {model_dir}")
|
|
|
|
# 列出所有文件
|
|
files = []
|
|
for root, dirs, filenames in os.walk(model_dir):
|
|
for filename in filenames:
|
|
if filename.endswith(('.pth', '.pt')):
|
|
full_path = os.path.join(root, filename)
|
|
rel_path = os.path.relpath(full_path, model_dir)
|
|
files.append((full_path, rel_path))
|
|
|
|
if not files:
|
|
print("INFO: 没有找到模型文件")
|
|
return
|
|
|
|
print(f"\n找到 {len(files)} 个模型文件:")
|
|
|
|
for full_path, rel_path in files:
|
|
print(f"\n文件: {rel_path}")
|
|
|
|
# 检查文件大小
|
|
size = os.path.getsize(full_path)
|
|
print(f" 大小: {size:,} bytes")
|
|
|
|
# 尝试解析文件名
|
|
try:
|
|
from utils.model_manager import model_manager
|
|
info = model_manager.parse_model_filename(os.path.basename(full_path))
|
|
if info:
|
|
print(f" 解析结果:")
|
|
print(f" 模型类型: {info.get('model_type', 'N/A')}")
|
|
print(f" 产品ID: {info.get('product_id', 'N/A')}")
|
|
print(f" 训练模式: {info.get('training_mode', 'N/A')}")
|
|
print(f" 版本: {info.get('version', 'N/A')}")
|
|
if info.get('store_id'):
|
|
print(f" 店铺ID: {info.get('store_id')}")
|
|
if info.get('aggregation_method'):
|
|
print(f" 聚合方法: {info.get('aggregation_method')}")
|
|
else:
|
|
print(" 解析: 无法解析文件名")
|
|
except Exception as e:
|
|
print(f" 解析错误: {e}")
|
|
|
|
# 尝试读取模型内容
|
|
try:
|
|
import torch
|
|
data = torch.load(full_path, map_location='cpu')
|
|
if isinstance(data, dict):
|
|
print(f" 内容包含:")
|
|
keys = list(data.keys())
|
|
print(f" 主要键: {keys[:5]}")
|
|
|
|
if 'model_manager_info' in data:
|
|
info = data['model_manager_info']
|
|
print(f" 管理信息: 包含 ({len(info)} 项)")
|
|
print(f" 产品名称: {info.get('product_name', 'N/A')}")
|
|
print(f" 创建时间: {info.get('created_at', 'N/A')}")
|
|
|
|
if 'metrics' in data:
|
|
metrics = data['metrics']
|
|
print(f" 评估指标: {len(metrics)} 项")
|
|
if 'rmse' in metrics:
|
|
print(f" RMSE: {metrics['rmse']:.4f}")
|
|
except Exception as e:
|
|
print(f" 读取错误: {e}")
|
|
|
|
def check_unified_api():
|
|
"""检查统一API是否正常工作"""
|
|
print("\n=== 检查统一API ===")
|
|
|
|
try:
|
|
from utils.model_manager import model_manager
|
|
|
|
models = model_manager.list_models()
|
|
print(f"统一管理器找到 {len(models)} 个模型")
|
|
|
|
if models:
|
|
print("\n模型详情:")
|
|
for i, model in enumerate(models):
|
|
print(f"{i+1}. {model.get('filename', 'N/A')}")
|
|
print(f" 产品: {model.get('product_name', model.get('product_id', 'N/A'))}")
|
|
print(f" 类型: {model.get('model_type', 'N/A')}")
|
|
print(f" 模式: {model.get('training_mode', 'N/A')}")
|
|
print(f" 文件大小: {model.get('file_size', 0):,} bytes")
|
|
|
|
return True
|
|
except Exception as e:
|
|
print(f"统一管理器错误: {e}")
|
|
return False
|
|
|
|
def main():
|
|
"""主函数"""
|
|
print("检查模型文件和管理系统")
|
|
|
|
check_model_files()
|
|
check_unified_api()
|
|
|
|
print("\n=== 总结 ===")
|
|
print("✅ 模型文件命名应该使用统一格式")
|
|
print("✅ 新训练的模型包含管理信息")
|
|
print("✅ 模型管理器可以正确解析和列出模型")
|
|
print("\n下一步:")
|
|
print("1. 启动API服务器测试前端集成")
|
|
print("2. 在前端查看模型是否正确显示")
|
|
print("3. 测试不同训练模式的模型命名")
|
|
|
|
if __name__ == "__main__":
|
|
main() |