116 lines
3.2 KiB
Python
116 lines
3.2 KiB
Python
#!/usr/bin/env python3
|
|
"""
|
|
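Smoke tests for the training server, with ASCII-only status messages.

Checks that the progress manager and the mLSTM trainer import, drives the
progress manager through one short training lifecycle, and pings the local
API server.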
|
|
"""
|
|
import sys
|
|
import os
|
|
import requests
|
|
import time
|
|
|
|
# Make the sibling ``server`` directory importable so the ``utils`` and
# ``trainers`` imports below resolve to this project's packages.
# - abspath: dirname(__file__) may be relative, which would tie imports to
#   the current working directory.
# - insert(0): an unrelated installed package that happens to be named
#   ``utils`` must not shadow the local one, which append() would allow.
# - membership check: executing this twice must not duplicate the entry.
server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'server')
if server_path not in sys.path:
    sys.path.insert(0, server_path)
|
|
|
|
def test_imports():
    """Check that the progress manager and the mLSTM trainer import cleanly.

    After importing, inspects ``train_product_model_with_mlstm`` and reports
    whether its signature has ``socketio`` and ``task_id`` parameters. That
    check is informational only. The line is prefixed ``OK:`` when both
    parameters are present and ``WARN:`` otherwise. A missing parameter
    never makes this function return False.

    Returns:
        True if the imports and signature inspection succeed. False if
        anything raised; the error and its traceback are printed.
    """
    print("=== Testing Imports ===")

    try:
        # Imported only to prove the module loads; the object itself is unused.
        from utils.training_progress import progress_manager  # noqa: F401
        print("OK: Progress manager imported")

        from trainers.mlstm_trainer import train_product_model_with_mlstm
        print("OK: mLSTM trainer imported")

        import inspect

        params = inspect.signature(train_product_model_with_mlstm).parameters
        has_socketio = 'socketio' in params
        has_task_id = 'task_id' in params
        # Label the result honestly so a missing parameter stands out.
        status = "OK" if has_socketio and has_task_id else "WARN"
        print(f"{status}: Has socketio={has_socketio}, task_id={has_task_id}")

        return True

    except Exception as e:
        # Broad on purpose: importing a module executes its body, which can
        # raise far more than ImportError.
        print(f"ERROR: Import failed - {e}")
        import traceback
        traceback.print_exc()
        return False
|
|
|
|
def test_progress_manager():
    """Drive the shared progress manager through one short training run.

    Temporarily replaces ``progress_manager.websocket_callback`` with a
    callback that prints each message's ``event_type``. It then calls
    ``start_training``, ``set_stage``, ``start_epoch``, ``update_batch``,
    ``finish_epoch`` and ``finish_training`` with fixed sample values.
    The previous callback, or its absence, is restored afterwards whether
    or not the run succeeds. This keeps the shared manager from staying
    wired to this test's print-only callback.

    Returns:
        True if every call succeeded. False if anything raised; the error
        and its traceback are printed.
    """
    print("\n=== Testing Progress Manager ===")

    absent = object()  # marks "no websocket_callback attribute beforehand"
    saved_callback = absent
    installed = False

    try:
        from utils.training_progress import progress_manager

        def test_callback(message):
            print(f"CALLBACK: {message.get('event_type', 'unknown')}")

        # Remember the current callback so the shared manager can be reset.
        saved_callback = getattr(progress_manager, 'websocket_callback', absent)
        progress_manager.websocket_callback = test_callback
        installed = True

        progress_manager.start_training(
            training_id="test-123",
            product_id="P001",
            model_type="mlstm",
            training_mode="product",
            total_epochs=2,
            total_batches=10,
            batch_size=32,
            total_samples=320,
        )
        print("OK: Progress manager started")

        progress_manager.set_stage("data_preprocessing", 50)
        print("OK: Stage update")

        progress_manager.start_epoch(0)
        print("OK: Epoch started")

        progress_manager.update_batch(5, 0.1234, 0.001)
        print("OK: Batch update")

        progress_manager.finish_epoch(0.1234, 0.1567)
        print("OK: Epoch finished")

        progress_manager.finish_training(success=True)
        print("OK: Training finished")

        return True

    except Exception as e:
        print(f"ERROR: Progress manager test failed - {e}")
        import traceback
        traceback.print_exc()
        return False

    finally:
        # progress_manager is always bound when installed is True.
        if installed:
            if saved_callback is absent:
                try:
                    del progress_manager.websocket_callback
                except AttributeError:
                    pass
            else:
                progress_manager.websocket_callback = saved_callback
|
|
|
|
def test_api(base_url="http://localhost:5000", timeout=5):
    """Check that the API server answers ``GET <base_url>/api/version``.

    Args:
        base_url: Scheme, host and port of the API server. The default is
            the local server the original script targeted. A trailing
            slash is tolerated.
        timeout: Seconds to wait for the HTTP response.

    Returns:
        True on HTTP 200. False on any other status code or on a
        request-level error such as a refused connection or timeout; the
        reason is printed.
    """
    print("\n=== Testing API ===")

    url = f"{base_url.rstrip('/')}/api/version"
    try:
        response = requests.get(url, timeout=timeout)
    except requests.RequestException as e:
        # Base class for connection errors, timeouts and invalid URLs raised
        # by requests; other exceptions indicate a bug and should propagate.
        print(f"ERROR: Cannot connect to API - {e}")
        return False

    if response.status_code == 200:
        print("OK: API server is running")
        return True

    print(f"ERROR: API returned {response.status_code}")
    return False
|
|
|
|
def main():
    """Run every smoke check, print a summary, and return an exit status.

    All checks run even if an earlier one fails, so a single run reports
    everything that is broken.

    Returns:
        0 if every check passed, 1 otherwise. The value is suitable for
        passing to ``sys.exit``.
    """
    print("=== UNICODE-FREE TESTING ===")

    success = True

    if not test_imports():
        success = False

    if not test_progress_manager():
        success = False

    if not test_api():
        success = False
        print("NOTE: Make sure API server is running")

    if success:
        print("\n=== ALL TESTS PASSED ===")
    else:
        print("\n=== SOME TESTS FAILED ===")

    return 0 if success else 1
|
|
|
|
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None maps to 0),
    # so shells and CI can detect a failed run.
    sys.exit(main())