# 71 lines, 2.0 KiB, Python (file-viewer metadata captured with this script; not code)
#!/usr/bin/env python3
"""Smoke test for the training progress manager and the mLSTM trainer.

Run directly. The script:

1. imports ``progress_manager`` from ``utils.training_progress``, resolved
   from the ``server`` directory next to this file;
2. installs a callback on ``progress_manager.websocket_callback`` that prints
   the ``event_type`` of each message it receives;
3. drives one short fake run through ``start_training``, ``set_stage``,
   ``start_epoch`` and ``finish_training``;
4. imports ``train_product_model_with_mlstm`` and checks that its signature
   declares the ``socketio`` and ``task_id`` parameters.

The process exits with status 0 only if every step succeeds, and 1 otherwise.
Nothing runs when the module is merely imported.
"""
import inspect
import os
import sys
import traceback

# Parameter names the trainer's signature must declare for this check to pass.
REQUIRED_TRAINER_PARAMS = ("socketio", "task_id")


def _add_server_to_path():
    """Put the ``server`` directory next to this file at the front of sys.path."""
    # abspath: before Python 3.9 the main script's ``__file__`` could be
    # relative, which would make the lookup depend on the working directory.
    server_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'server')
    # Insert at the front (not append) so the project's generically named
    # ``utils`` / ``trainers`` packages win over same-named installed ones.
    if server_path not in sys.path:
        sys.path.insert(0, server_path)


def missing_parameters(func, required=REQUIRED_TRAINER_PARAMS):
    """Return the names in *required* that *func*'s signature does not declare.

    Args:
        func: Any callable that ``inspect.signature`` can introspect.
        required: Parameter names to look for.

    Returns:
        The missing names, in the order given by *required*. An empty list
        means all of them are present.
    """
    declared = inspect.signature(func).parameters
    return [name for name in required if name not in declared]


def main():
    """Run the smoke test and return a process exit status (0 = pass, 1 = fail)."""
    _add_server_to_path()
    print("=== BASIC TEST START ===")

    try:
        print("1. Testing progress manager import...")
        from utils.training_progress import progress_manager
        print(" OK: Progress manager imported")

        print("2. Testing callback setup...")

        def print_event(message):
            # Reads the payload via ``.get``; assumed dict-like (TODO confirm
            # against what progress_manager actually passes).
            event = message.get('event_type', 'unknown')
            print(f" CALLBACK: {event}")

        progress_manager.websocket_callback = print_event
        print(" OK: Callback set")

        print("3. Testing progress manager functions...")
        progress_manager.start_training(
            training_id="test-123",
            product_id="P001",
            model_type="mlstm",
            training_mode="product",
            total_epochs=2,
            total_batches=10,
            batch_size=32,
            total_samples=320,
        )
        print(" OK: Training started")

        progress_manager.set_stage("data_preprocessing", 50)
        print(" OK: Stage set")

        progress_manager.start_epoch(0)
        print(" OK: Epoch started")

        progress_manager.finish_training(success=True)
        print(" OK: Training finished")

        print("4. Testing trainer import...")
        from trainers.mlstm_trainer import train_product_model_with_mlstm
        print(" OK: Trainer imported")

        params = list(inspect.signature(train_product_model_with_mlstm).parameters)
        print(f" Parameters: {params}")

        missing = missing_parameters(train_product_model_with_mlstm)
        print(f" Has socketio: {'socketio' not in missing}")
        print(f" Has task_id: {'task_id' not in missing}")

        if missing:
            # A missing parameter is a failure: it must not reach the
            # success banner or exit with status 0.
            print(" ERROR: Missing required parameters")
            print("\n=== TEST FAILED ===")
            return 1
        print(" OK: Required parameters present")

    except Exception as e:
        # Top-level boundary: report any failure, including import errors,
        # with a full traceback, and signal it through the exit status.
        print(f"ERROR: {e}")
        traceback.print_exc()
        print("\n=== TEST FAILED ===")
        return 1

    print("\n=== TEST COMPLETED SUCCESSFULLY ===")
    return 0


if __name__ == "__main__":
    sys.exit(main())