ShopTRAINING/test/test_model_id_fix.py

107 lines
3.9 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
"""
测试模型ID修复
验证API返回的模型列表中model_id字段是否正确填充
"""
import sys
import os
import requests
import json
# Make the backend `server` package (which provides `utils.model_manager`)
# importable. The original lookup was `<script dir>/server`; this script's
# repository path is ShopTRAINING/test/, so `server/` one level up is also
# added — presumably the project root; confirm against the repo layout.
# The sibling directory is inserted last so it keeps first priority, exactly
# as before.
_HERE = os.path.dirname(os.path.abspath(__file__))
for _server_dir in (os.path.join(os.path.dirname(_HERE), 'server'),
                    os.path.join(_HERE, 'server')):
    if _server_dir not in sys.path:
        sys.path.insert(0, _server_dir)
def _build_model_id(model):
    """Return the model_id for one entry of ``model_manager.list_models()``.

    Mirrors the derivation the script attributes to the API: the stored
    filename with ``.pth`` removed when available; otherwise an id assembled
    from ``training_mode`` ('store' with a store_id -> store-scoped id,
    'global' -> id including the aggregation method, anything else -> the
    per-product id).

    Args:
        model: dict describing one saved model (keys read below).

    Returns:
        The model_id string. The fallback always formats at least its default
        placeholders, so the result is never empty.
    """
    # `or ''` also covers a filename key that is present but None, which would
    # otherwise raise AttributeError and abort the whole listing.
    # NOTE(review): str.replace strips every '.pth', not only the suffix; kept
    # as-is so this stays identical to the API logic it mirrors — confirm.
    model_id = (model.get('filename') or '').replace('.pth', '')
    if model_id:
        return model_id

    product_id = model.get('product_id', 'unknown')
    model_type = model.get('model_type', 'unknown')
    version = model.get('version', 'v1')
    training_mode = model.get('training_mode', 'product')
    store_id = model.get('store_id')

    if training_mode == 'store' and store_id:
        return f"{model_type}_store_{store_id}_{product_id}_{version}"
    if training_mode == 'global':
        aggregation_method = model.get('aggregation_method', 'mean')
        return f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
    return f"{model_type}_product_{product_id}_{version}"


def test_model_manager_directly():
    """Exercise ModelManager directly and print the derived model_id per model.

    Imports the project's ``utils.model_manager`` singleton, then prints its
    model directory, whether that directory exists, and the number of models.
    For at most the first 3 models it prints the key fields and the model_id
    produced by ``_build_model_id``. Any failure is printed with a traceback
    instead of propagating, so the script carries on to its next check.
    """
    print("=== 直接测试ModelManager ===")
    try:
        # Imported lazily so a missing/broken server package is reported here
        # rather than preventing the whole script from loading.
        from utils.model_manager import model_manager

        print(f"模型目录: {model_manager.model_dir}")
        print(f"目录存在: {os.path.exists(model_manager.model_dir)}")

        models = model_manager.list_models()
        print(f"找到 {len(models)} 个模型")

        for i, model in enumerate(models[:3]):  # only show the first 3
            print(f"\n模型 {i+1}:")
            for field in ('filename', 'product_id', 'model_type',
                          'version', 'training_mode'):
                print(f" {field}: {model.get(field, 'MISSING')}")
            print(f" 生成的model_id: {_build_model_id(model)}")
    except Exception as e:
        # Top-level boundary of a diagnostic script: report and continue.
        print(f"直接测试失败: {e}")
        import traceback
        traceback.print_exc()
def test_api_endpoint(api_url="http://localhost:5000/api/models"):
    """Query the model-list API and report whether each model_id is filled in.

    Args:
        api_url: models endpoint to query; defaults to the local dev server.
            It has a default value so callers (and pytest, which does not
            treat defaulted parameters as fixtures) can call this with no
            arguments.

    Prints the key fields of at most the first 3 models in the response's
    ``data`` list, each with a ✅/❌ marker for ``model_id``. A non-200 status
    prints the status code and body. Connection failures and any other
    errors are printed, never raised.
    """
    from urllib.parse import urlsplit

    print("\n=== 测试API端点 ===")
    # scheme://host:port of the endpoint, for the connection-error hint; this
    # keeps the message in sync with api_url instead of a duplicated literal.
    parts = urlsplit(api_url)
    server_root = f"{parts.scheme}://{parts.netloc}"
    try:
        response = requests.get(api_url, timeout=10)
        if response.status_code != 200:
            print(f"API请求失败: {response.status_code}")
            print(f"响应: {response.text}")
            return

        data = response.json()
        models = data.get('data', [])
        print(f"API返回 {len(models)} 个模型")

        for i, model in enumerate(models[:3]):  # only show the first 3
            print(f"\n模型 {i+1}:")
            print(f" model_id: {model.get('model_id', 'MISSING')}")
            print(f" product_id: {model.get('product_id', 'MISSING')}")
            print(f" model_type: {model.get('model_type', 'MISSING')}")
            print(f" version: {model.get('version', 'MISSING')}")
            # The fix under test: model_id must be a non-empty value.
            if not model.get('model_id'):
                print(" ❌ model_id为空!")
            else:
                print(" ✅ model_id已填充")
    except requests.exceptions.ConnectionError:
        print(f"无法连接到API服务器 ({server_root})")
        print("请确保API服务器正在运行")
    except Exception as e:
        print(f"API测试失败: {e}")
if __name__ == "__main__":
    # Script entry point: run the direct ModelManager check, then the HTTP API
    # check. Both catch and print their own exceptions, so both always run.
    print("测试模型ID修复...")
    for run_check in (test_model_manager_directly, test_api_endpoint):
        run_check()
    print("\n测试完成!")