ShopTRAINING/test/test_api_logic.py

102 lines
4.1 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
"""
Test script to verify API logic directly
"""
import os
import sys
# Make the backend's ``server`` directory importable so that
# ``utils.model_manager`` (imported by the test below) resolves.
# This script sits in ``test/`` while the backend presumably lives in the
# sibling ``server/`` directory, so the parent's ``server`` is searched too.
# The original location (``server`` next to this file) is still added and
# keeps the highest priority. Absolute paths are used so the result does not
# depend on the current working directory.
_script_dir = os.path.dirname(os.path.abspath(__file__))
server_path = os.path.join(_script_dir, 'server')
_parent_server_path = os.path.join(os.path.dirname(_script_dir), 'server')
for _path in (_parent_server_path, server_path):
    # Avoid piling up duplicate entries if this module is imported repeatedly.
    if _path not in sys.path:
        sys.path.insert(0, _path)
def _format_model(model):
    """Convert one ``model_manager.list_models()`` entry into the API shape.

    Per the original test, this is the same logic as the models endpoint in
    the server's ``api.py`` (lines 1944-1959 when this test was written);
    keep the two in sync so the test exercises what the endpoint returns.

    Args:
        model: A dict entry as returned by ``model_manager.list_models()``.

    Returns:
        A new dict with the fields the models endpoint exposes.
    """
    # Prefer the file name (with ".pth" removed) as the model ID.
    # NOTE: str.replace drops every ".pth", not only a trailing one; kept
    # as-is to stay identical to the api.py logic being mirrored.
    model_id = model.get('filename', '').replace('.pth', '')
    if not model_id:
        # Fallback: build an ID from the model's metadata.
        product_id = model.get('product_id', 'unknown')
        model_type = model.get('model_type', 'unknown')
        version = model.get('version', 'v1')
        training_mode = model.get('training_mode', 'product')
        store_id = model.get('store_id')
        if training_mode == 'store' and store_id:
            model_id = f"{model_type}_store_{store_id}_{product_id}_{version}"
        elif training_mode == 'global':
            aggregation_method = model.get('aggregation_method', 'mean')
            model_id = f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
        else:
            model_id = f"{model_type}_product_{product_id}_{version}"

    return {
        'model_id': model_id,
        'product_id': model.get('product_id', ''),
        'product_name': model.get('product_name', model.get('product_id', '')),
        'model_type': model.get('model_type', ''),
        'training_mode': model.get('training_mode', 'product'),
        'store_id': model.get('store_id'),
        'aggregation_method': model.get('aggregation_method'),
        'version': model.get('version', 'v1'),
        'created_at': model.get('created_at', model.get('modified_at', '')),
        'file_size': model.get('file_size', 0),
        'metrics': model.get('metrics', {}),
        'config': model.get('config', {}),
    }


# Fields echoed for each formatted model in step 3, in display order.
_DISPLAY_FIELDS = ('model_id', 'product_id', 'model_type', 'training_mode', 'version', 'file_size')


def test_api_models_logic():
    """Exercise the models-endpoint logic directly, without HTTP.

    Steps:
      1. List all models via ``model_manager.list_models()``.
      2. Format each one the way the API does (``_format_model``).
      3. Print each formatted model and check its ``model_id`` is non-empty.
      4. Call ``list_models`` with ``product_id`` / ``model_type`` filters.

    This is a script-style check: progress goes to stdout, and any exception
    is printed with its traceback instead of being propagated.

    Returns:
        True if every step ran and every model has a non-empty ``model_id``;
        False otherwise.
    """
    print("=== Testing API Models Logic ===")
    try:
        # Imported inside the try so a missing/unimportable backend is
        # reported by the handler below instead of crashing the script.
        from utils.model_manager import model_manager

        print("1. Testing ModelManager.list_models()...")
        models = model_manager.list_models()
        print(f"Found {len(models)} models")

        print("\n2. Testing API formatting logic...")
        formatted_models = [_format_model(model) for model in models]
        print(f"Formatted {len(formatted_models)} models")

        print("\n3. Showing formatted models...")
        missing_ids = 0
        for index, model in enumerate(formatted_models, start=1):
            print(f"\nModel {index}:")
            for field in _DISPLAY_FIELDS:
                print(f" {field}: {model[field]}")
            # With the fallback in _format_model this should never trigger;
            # it guards against regressions in the ID logic.
            if model['model_id']:
                print(" [OK] model_id is properly set")
            else:
                print(" [ERROR] model_id is missing or empty")
                missing_ids += 1

        print("\n4. Testing filters...")
        # Test product_id filter
        filtered_models = model_manager.list_models(product_id='P001')
        print(f"P001 models: {len(filtered_models)}")
        # Test model_type filter
        mlstm_models = model_manager.list_models(model_type='mlstm')
        print(f"MLSTM models: {len(mlstm_models)}")
        kan_models = model_manager.list_models(model_type='kan_optimized')
        print(f"KAN optimized models: {len(kan_models)}")

        # A failed per-model check must not be followed by a success report.
        if missing_ids:
            print(f"\n[ERROR] {missing_ids} model(s) have an empty model_id")
            return False
        print("\n[SUCCESS] All tests completed successfully!")
        return True
    except Exception as e:
        # Top-level boundary of this script-style test: report and signal failure.
        print(f"[ERROR] Error: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Propagate the boolean result as the process exit status (0 = pass,
    # 1 = fail) so shells and CI can detect a failed run.
    sys.exit(0 if test_api_models_logic() else 1)