ShopTRAINING/test/test_model_id_fix.py
2025-07-02 11:05:23 +08:00

107 lines
3.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试模型ID修复
验证API返回的模型列表中model_id字段是否正确填充
"""
import sys
import os
import requests
import json
# Make the project's ``server`` directory importable so that
# ``utils.model_manager`` (imported by test_model_manager_directly) resolves.
# Per the repository path this script lives in ``<project>/test/``, so
# ``server`` is normally a sibling of this file's parent directory. A
# ``server`` directory right next to this file (the layout the original code
# assumed) is still honoured and, when both exist, takes precedence.
_THIS_DIR = os.path.dirname(os.path.abspath(__file__))
_SERVER_DIR_CANDIDATES = (
    os.path.join(os.path.dirname(_THIS_DIR), 'server'),
    os.path.join(_THIS_DIR, 'server'),
)
for _server_dir in _SERVER_DIR_CANDIDATES:
    # Each insert goes to the front of sys.path, so later candidates are
    # searched first; skip directories that do not exist or are already listed.
    if os.path.isdir(_server_dir) and _server_dir not in sys.path:
        sys.path.insert(0, _server_dir)
def test_model_manager_directly():
    """Exercise ModelManager in-process and print what it reports.

    Imports ``utils.model_manager.model_manager`` and prints its model
    directory and whether that directory exists. It then lists the models and,
    for the first three, prints their key fields together with the model_id
    derived for each. Any exception is caught and reported: the message goes to
    stdout and the traceback to stderr. Nothing is asserted; returns None.
    """

    def derive_model_id(info):
        # Mirrors the model_id derivation used by the API (per the original
        # note). Prefer the filename with '.pth' removed; otherwise compose an
        # id from the model's metadata according to its training mode.
        from_filename = info.get('filename', '').replace('.pth', '')
        if from_filename:
            return from_filename
        product = info.get('product_id', 'unknown')
        kind = info.get('model_type', 'unknown')
        ver = info.get('version', 'v1')
        mode = info.get('training_mode', 'product')
        store = info.get('store_id')
        if mode == 'store' and store:
            return f"{kind}_store_{store}_{product}_{ver}"
        if mode == 'global':
            agg = info.get('aggregation_method', 'mean')
            return f"{kind}_global_{product}_{agg}_{ver}"
        return f"{kind}_product_{product}_{ver}"

    shown_fields = ('filename', 'product_id', 'model_type', 'version', 'training_mode')

    print("=== 直接测试ModelManager ===")
    try:
        from utils.model_manager import model_manager

        # Report where the manager looks for models and whether that path exists.
        model_dir = model_manager.model_dir
        print(f"模型目录: {model_dir}")
        print(f"目录存在: {os.path.exists(model_dir)}")

        catalog = model_manager.list_models()
        print(f"找到 {len(catalog)} 个模型")

        # Only the first three models are shown to keep the output short.
        for position, info in enumerate(catalog[:3], start=1):
            print(f"\n模型 {position}:")
            for field in shown_fields:
                print(f" {field}: {info.get(field, 'MISSING')}")
            print(f" 生成的model_id: {derive_model_id(info)}")
    except Exception as exc:
        print(f"直接测试失败: {exc}")
        import traceback
        traceback.print_exc()
def test_api_endpoint(api_url="http://localhost:5000/api/models", *, timeout=10):
    """Query the model-list API and report whether each model's model_id is filled.

    Sends ``GET api_url``. On HTTP 200 it prints ``model_id``, ``product_id``,
    ``model_type`` and ``version`` for the first three entries of the JSON
    response's ``data`` list, flagging entries whose ``model_id`` is empty.

    Failures are printed rather than raised:
    - a non-200 status prints the status code and response body;
    - a connection error prints a hint naming the server that was contacted;
    - any other exception prints its message.

    Nothing is asserted; returns None. Both parameters have defaults, so
    pytest does not try to resolve them as fixtures.

    Args:
        api_url: Full URL of the model-list endpoint. Defaults to the local
            development server.
        timeout: Request timeout in seconds, passed to ``requests.get``.
    """
    from urllib.parse import urlsplit

    print("\n=== 测试API端点 ===")
    try:
        response = requests.get(api_url, timeout=timeout)
        if response.status_code == 200:
            data = response.json()
            # The model list is read from the response's 'data' key.
            models = data.get('data', [])
            print(f"API返回 {len(models)} 个模型")
            for i, model in enumerate(models[:3]):  # only show the first 3
                print(f"\n模型 {i+1}:")
                print(f" model_id: {model.get('model_id', 'MISSING')}")
                print(f" product_id: {model.get('product_id', 'MISSING')}")
                print(f" model_type: {model.get('model_type', 'MISSING')}")
                print(f" version: {model.get('version', 'MISSING')}")
                # The behaviour under test: every listed model should carry a
                # non-empty model_id.
                if not model.get('model_id'):
                    print(" ❌ model_id为空!")
                else:
                    print(" ✅ model_id已填充")
        else:
            print(f"API请求失败: {response.status_code}")
            print(f"响应: {response.text}")
    except requests.exceptions.ConnectionError:
        # Name the server actually contacted (scheme://host:port of api_url)
        # instead of a separately hard-coded address that could drift.
        parts = urlsplit(api_url)
        print(f"无法连接到API服务器 ({parts.scheme}://{parts.netloc})")
        print("请确保API服务器正在运行")
    except Exception as e:
        print(f"API测试失败: {e}")
if __name__ == "__main__":
    print("测试模型ID修复...")
    # Run the in-process ModelManager check first, then the HTTP check
    # against the running API server; each catches and prints its own errors.
    for run_check in (test_model_manager_directly, test_api_endpoint):
        run_check()
    print("\n测试完成!")