ShopTRAINING/test/test_api_logic.py
2025-07-02 11:05:23 +08:00

102 lines
4.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
Test script to verify API logic directly
"""
import os
import sys
# Make the backend package (which provides utils.model_manager) importable.
# Look for a "server" directory next to this script first. This file sits in a
# test/ subdirectory according to its repository path, so when no sibling
# "server" directory exists, fall back to the parent directory's "server".
# NOTE(review): confirm the actual repository layout.
# abspath() guards against dirname(__file__) being '' on relative invocation.
_this_dir = os.path.dirname(os.path.abspath(__file__))
server_path = os.path.join(_this_dir, 'server')
if not os.path.isdir(server_path):
    server_path = os.path.join(os.path.dirname(_this_dir), 'server')
sys.path.insert(0, server_path)
def _format_model(model):
    """Build the models-endpoint response dict for one model record.

    The model_id derivation mirrors the server's api.py logic (lines
    1944-1959 according to this script's original comment); keep the two in
    sync, since this script exists to exercise that logic directly.

    Args:
        model: a model-info dict as returned by ``model_manager.list_models()``.
            Keys read: filename, product_id, product_name, model_type,
            training_mode, store_id, aggregation_method, version, created_at,
            modified_at, file_size, metrics, config.

    Returns:
        dict with the fields the models endpoint reports for a model.
    """
    # Primary ID: the file name without '.pth'. str.replace removes '.pth'
    # anywhere in the name (not only as a suffix); kept as-is to mirror api.py.
    model_id = model.get('filename', '').replace('.pth', '')
    if not model_id:
        # Fallback: generate an ID from the model's metadata.
        product_id = model.get('product_id', 'unknown')
        model_type = model.get('model_type', 'unknown')
        version = model.get('version', 'v1')
        training_mode = model.get('training_mode', 'product')
        store_id = model.get('store_id')
        if training_mode == 'store' and store_id:
            model_id = f"{model_type}_store_{store_id}_{product_id}_{version}"
        elif training_mode == 'global':
            aggregation_method = model.get('aggregation_method', 'mean')
            model_id = f"{model_type}_global_{product_id}_{aggregation_method}_{version}"
        else:
            model_id = f"{model_type}_product_{product_id}_{version}"

    return {
        'model_id': model_id,
        'product_id': model.get('product_id', ''),
        'product_name': model.get('product_name', model.get('product_id', '')),
        'model_type': model.get('model_type', ''),
        'training_mode': model.get('training_mode', 'product'),
        'store_id': model.get('store_id'),
        'aggregation_method': model.get('aggregation_method'),
        'version': model.get('version', 'v1'),
        'created_at': model.get('created_at', model.get('modified_at', '')),
        'file_size': model.get('file_size', 0),
        'metrics': model.get('metrics', {}),
        'config': model.get('config', {}),
    }


def _print_formatted_model(index, model):
    """Print a summary of one formatted model and check its model_id.

    Args:
        index: 1-based position used in the "Model N:" heading.
        model: a dict produced by ``_format_model``.

    Returns:
        True if ``model['model_id']`` is non-empty, False otherwise.
    """
    print(f"\nModel {index}:")
    for key in ('model_id', 'product_id', 'model_type',
                'training_mode', 'version', 'file_size'):
        print(f" {key}: {model[key]}")
    if model['model_id']:
        print(" [OK] model_id is properly set")
        return True
    print(" [ERROR] model_id is missing or empty")
    return False


def test_api_models_logic():
    """Exercise the models-endpoint logic directly, without the web server.

    Steps:
      1. List models via ``utils.model_manager.model_manager.list_models()``.
      2. Format each one as the API does (see ``_format_model``).
      3. Print each formatted model and check that its model_id is set.
      4. Call ``list_models`` with product_id / model_type filters.

    Returns:
        True if every step ran and every formatted model has a non-empty
        model_id; False otherwise. Any exception is caught, printed with its
        traceback, and reported as False.
        NOTE(review): failures are returned, not raised, so a test runner
        such as pytest will not flag them; check the return value.
    """
    print("=== Testing API Models Logic ===")
    try:
        from utils.model_manager import model_manager

        print("1. Testing ModelManager.list_models()...")
        models = model_manager.list_models()
        print(f"Found {len(models)} models")

        print("\n2. Testing API formatting logic...")
        formatted_models = [_format_model(model) for model in models]
        print(f"Formatted {len(formatted_models)} models")

        print("\n3. Showing formatted models...")
        invalid_count = 0
        for index, model in enumerate(formatted_models, start=1):
            if not _print_formatted_model(index, model):
                invalid_count += 1

        print("\n4. Testing filters...")
        # product_id filter
        filtered_models = model_manager.list_models(product_id='P001')
        print(f"P001 models: {len(filtered_models)}")
        # model_type filters
        mlstm_models = model_manager.list_models(model_type='mlstm')
        print(f"MLSTM models: {len(mlstm_models)}")
        kan_models = model_manager.list_models(model_type='kan_optimized')
        print(f"KAN optimized models: {len(kan_models)}")

        # Do not report success when step 3 found invalid model IDs.
        if invalid_count:
            print(f"\n[ERROR] {invalid_count} model(s) have a missing or empty model_id")
            return False
        print("\n[SUCCESS] All tests completed successfully!")
        return True
    except Exception as e:
        # Top-level script boundary: report the error with its traceback and
        # signal failure through the return value.
        print(f"[ERROR] Error: {e}")
        import traceback
        traceback.print_exc()
        return False
if __name__ == "__main__":
    # Propagate the result as the process exit status so shell scripts / CI can
    # detect a failed run (the function signals failure by returning False).
    sys.exit(0 if test_api_models_logic() else 1)