102 lines
4.1 KiB
Python
102 lines
4.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
Test script to verify API logic directly
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
|
||
# Put the sibling ``server`` directory first on the import path so that
# ``utils.model_manager`` (imported lazily inside test_api_models_logic)
# resolves to the project's server package.  The directory is made absolute
# so the entry stays valid even if the current working directory changes
# before that lazy import runs.
server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'server')
sys.path.insert(0, server_path)
|
||
|
||
def test_api_models_logic():
|
||
"""Test the models endpoint logic directly"""
|
||
|
||
print("=== Testing API Models Logic ===")
|
||
|
||
try:
|
||
from utils.model_manager import model_manager
|
||
|
||
print("1. Testing ModelManager.list_models()...")
|
||
models = model_manager.list_models()
|
||
print(f"Found {len(models)} models")
|
||
|
||
print("\n2. Testing API formatting logic...")
|
||
formatted_models = []
|
||
for model in models:
|
||
# This is the same logic as in api.py line 1944-1959
|
||
model_id = model.get('filename', '').replace('.pth', '')
|
||
if not model_id:
|
||
# 备用方案:基于模型信息生成ID
|
||
product_id = model.get('product_id', 'unknown')
|
||
model_type = model.get('model_type', 'unknown')
|
||
version = model.get('version', 'v1')
|
||
training_mode = model.get('training_mode', 'product')
|
||
store_id = model.get('store_id')
|
||
|
||
if training_mode == 'store' and store_id:
|
||
model_id = f"{model_type}_store_{store_id}_{product_id}_{version}"
|
||
elif training_mode == 'global':
|
||
aggregation_method = model.get('aggregation_method', 'mean')
|
||
model_id = f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
|
||
else:
|
||
model_id = f"{model_type}_product_{product_id}_{version}"
|
||
|
||
formatted_model = {
|
||
'model_id': model_id,
|
||
'product_id': model.get('product_id', ''),
|
||
'product_name': model.get('product_name', model.get('product_id', '')),
|
||
'model_type': model.get('model_type', ''),
|
||
'training_mode': model.get('training_mode', 'product'),
|
||
'store_id': model.get('store_id'),
|
||
'aggregation_method': model.get('aggregation_method'),
|
||
'version': model.get('version', 'v1'),
|
||
'created_at': model.get('created_at', model.get('modified_at', '')),
|
||
'file_size': model.get('file_size', 0),
|
||
'metrics': model.get('metrics', {}),
|
||
'config': model.get('config', {})
|
||
}
|
||
formatted_models.append(formatted_model)
|
||
|
||
print(f"Formatted {len(formatted_models)} models")
|
||
|
||
print("\n3. Showing formatted models...")
|
||
for i, model in enumerate(formatted_models):
|
||
print(f"\nModel {i+1}:")
|
||
print(f" model_id: {model['model_id']}")
|
||
print(f" product_id: {model['product_id']}")
|
||
print(f" model_type: {model['model_type']}")
|
||
print(f" training_mode: {model['training_mode']}")
|
||
print(f" version: {model['version']}")
|
||
print(f" file_size: {model['file_size']}")
|
||
|
||
# Check if model_id is valid
|
||
if model['model_id'] and model['model_id'] != '':
|
||
print(f" [OK] model_id is properly set")
|
||
else:
|
||
print(f" [ERROR] model_id is missing or empty")
|
||
|
||
print("\n4. Testing filters...")
|
||
# Test product_id filter
|
||
filtered_models = model_manager.list_models(product_id='P001')
|
||
print(f"P001 models: {len(filtered_models)}")
|
||
|
||
# Test model_type filter
|
||
mlstm_models = model_manager.list_models(model_type='mlstm')
|
||
print(f"MLSTM models: {len(mlstm_models)}")
|
||
|
||
kan_models = model_manager.list_models(model_type='kan_optimized')
|
||
print(f"KAN optimized models: {len(kan_models)}")
|
||
|
||
print(f"\n[SUCCESS] All tests completed successfully!")
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"[ERROR] Error: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
if __name__ == "__main__":
    # Turn the boolean result into the process exit status so shells and CI
    # see a non-zero code when the check fails.
    sys.exit(0 if test_api_models_logic() else 1)