#!/usr/bin/env python3
"""
Model ID fix verification script.

Checks that the ``model_id`` field in the model list is populated correctly,
both by inspecting ``ModelManager`` in-process and by querying the HTTP API.
"""

import sys
import os
import requests
import json

# Put the ``server`` directory that sits next to this script on the import
# path so ``utils.model_manager`` can be imported. abspath() keeps the entry
# valid even if the working directory changes later, since ``__file__`` may be
# a relative path on some interpreters.
sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'server'))


def _derive_model_id(model):
    """Return the model_id for one ``list_models()`` entry (mirrors the API logic).

    The id is the entry's filename with ``.pth`` removed. If there is no
    filename (missing, empty, or None), the id is rebuilt from metadata, with a
    layout chosen by ``training_mode``:
      - 'store' with a truthy store_id: ``{type}_store_{store}_{product}_{version}``
      - 'global': ``{type}_global_{product}_{aggregation}_{version}``
      - anything else: ``{type}_product_{product}_{version}``

    Args:
        model: dict-like metadata for a single model.

    Returns:
        The derived model_id string.
    """
    # `or ''` also covers an explicit None filename, which previously raised
    # AttributeError on .replace and aborted the whole listing.
    # NOTE(review): str.replace strips every '.pth' occurrence, not just the
    # suffix. It is kept this way so the script reproduces the API's result;
    # confirm against the API implementation.
    model_id = (model.get('filename') or '').replace('.pth', '')
    if model_id:
        return model_id

    product_id = model.get('product_id', 'unknown')
    model_type = model.get('model_type', 'unknown')
    version = model.get('version', 'v1')
    training_mode = model.get('training_mode', 'product')
    store_id = model.get('store_id')

    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{product_id}_{version}"
    if training_mode == 'global':
        aggregation_method = model.get('aggregation_method', 'mean')
        return f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
    # Store-mode entries without a store_id also fall back to the product layout.
    return f"{model_type}_product_{product_id}_{version}"


def test_model_manager_directly():
    """Inspect ``utils.model_manager.model_manager`` in-process and print findings.

    Prints the model directory and whether it exists, and the number of models
    found. For the first three models, prints their metadata fields and the
    model_id derived by ``_derive_model_id``.

    Any exception (including failure to import ``utils.model_manager``) is
    caught, its message is printed, and the traceback goes to stderr. Nothing
    is raised, so this function never fails if collected by pytest.
    """
    print("=== 直接测试ModelManager ===")

    try:
        # Imported lazily: needs the `server` directory on sys.path.
        from utils.model_manager import model_manager

        # Model directory
        print(f"模型目录: {model_manager.model_dir}")
        print(f"目录存在: {os.path.exists(model_manager.model_dir)}")

        # Model list
        models = model_manager.list_models()
        print(f"找到 {len(models)} 个模型")

        shown_fields = ('filename', 'product_id', 'model_type', 'version', 'training_mode')
        for i, model in enumerate(models[:3]):  # only show the first 3
            print(f"\n模型 {i+1}:")
            for field in shown_fields:
                print(f" {field}: {model.get(field, 'MISSING')}")
            print(f" 生成的model_id: {_derive_model_id(model)}")

    except Exception as e:
        # Top-level boundary of a diagnostic script: report, don't raise.
        print(f"直接测试失败: {e}")
        import traceback
        traceback.print_exc()


def test_api_endpoint(base_url="http://localhost:5000"):
    """Query ``{base_url}/api/models`` and check that model_id is populated.

    For the first three models, prints model_id, product_id, model_type and
    version with a per-model ✅/❌ verdict on model_id. Then prints how many of
    ALL returned models have an empty model_id, so gaps beyond the first three
    are not missed.

    Args:
        base_url: Root URL of the API server. The default is the previously
            hard-coded local server, so existing callers behave the same.

    Non-200 responses, connection failures and any other errors are reported
    via print; nothing is raised.
    """
    print("\n=== 测试API端点 ===")

    api_url = f"{base_url}/api/models"

    try:
        response = requests.get(api_url, timeout=10)

        if response.status_code != 200:
            print(f"API请求失败: {response.status_code}")
            print(f"响应: {response.text}")
            return

        data = response.json()
        # `or []` also covers an explicit `"data": null`, which would otherwise
        # crash len() below.
        models = data.get('data') or []
        print(f"API返回 {len(models)} 个模型")

        for i, model in enumerate(models[:3]):  # only show the first 3
            print(f"\n模型 {i+1}:")
            print(f" model_id: {model.get('model_id', 'MISSING')}")
            print(f" product_id: {model.get('product_id', 'MISSING')}")
            print(f" model_type: {model.get('model_type', 'MISSING')}")
            print(f" version: {model.get('version', 'MISSING')}")

            # Is model_id empty?
            if not model.get('model_id'):
                print(" ❌ model_id为空!")
            else:
                print(" ✅ model_id已填充")

        # Check every model, not just the ones displayed above.
        missing = sum(1 for model in models if not model.get('model_id'))
        print(f"\nmodel_id为空的模型: {missing}/{len(models)}")

    except requests.exceptions.ConnectionError:
        print(f"无法连接到API服务器 ({base_url})")
        print("请确保API服务器正在运行")
    except Exception as e:
        print(f"API测试失败: {e}")


if __name__ == "__main__":
    print("测试模型ID修复...")

    # Test ModelManager in-process first
    test_model_manager_directly()

    # Then test the running API endpoint
    test_api_endpoint()

    print("\n测试完成!")