ShopTRAINING/test/basic_test.py
2025-07-02 11:05:23 +08:00

71 lines
2.0 KiB
Python

#!/usr/bin/env python3
"""Smoke test for the training progress manager and the mLSTM trainer.

Run directly, e.g. ``python basic_test.py``. The script:

1. imports ``progress_manager`` from ``utils.training_progress`` (server code);
2. installs a ``websocket_callback`` that prints each message's ``event_type``;
3. drives one start_training / set_stage / start_epoch / finish_training cycle;
4. imports ``train_product_model_with_mlstm`` and checks that its signature
   declares the ``socketio`` and ``task_id`` parameters.

The process exits with status 0 when every step passes and 1 otherwise, so
shells and CI jobs can detect a failure. Importing this module (e.g. during
pytest collection) runs nothing.
"""
import inspect
import os
import sys
import traceback

# Parameters the mLSTM trainer must declare for this smoke test to pass.
REQUIRED_TRAINER_PARAMS = ("socketio", "task_id")


def candidate_server_paths(base_dir):
    """Return directories that may contain the server code, in priority order.

    ``<base_dir>/server`` (the location used historically) comes first; the
    sibling ``<parent of base_dir>/server`` covers this script living in a
    ``test/`` directory next to ``server/``.
    """
    return [
        os.path.join(base_dir, "server"),
        os.path.join(os.path.dirname(base_dir), "server"),
    ]


def add_server_path():
    """Prepend the server directory to ``sys.path`` and return it.

    The first existing candidate from ``candidate_server_paths`` wins; when
    none exists, ``<script dir>/server`` is used, matching the previous
    behaviour.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    candidates = candidate_server_paths(base_dir)
    server_path = next((p for p in candidates if os.path.isdir(p)), candidates[0])
    if server_path not in sys.path:
        # Prepend rather than append so the server's generically named
        # ``utils``/``trainers`` packages shadow any installed package with
        # the same name.
        sys.path.insert(0, server_path)
    return server_path


def find_missing_params(func, required=REQUIRED_TRAINER_PARAMS):
    """Return the names in *required* that ``func``'s signature does not declare.

    The result follows the order of *required*; an empty list means every
    required name is present.
    """
    declared = inspect.signature(func).parameters
    return [name for name in required if name not in declared]


def main():
    """Run the smoke test.

    Returns:
        0 when every step passes; 1 when any step raises (including import
        errors) or the trainer lacks a required parameter.
    """
    add_server_path()
    print("=== BASIC TEST START ===")
    try:
        print("1. Testing progress manager import...")
        from utils.training_progress import progress_manager
        print(" OK: Progress manager imported")

        print("2. Testing callback setup...")

        def print_progress_event(message):
            # Messages are read with dict-style .get(); a missing event type
            # is printed as 'unknown'.
            event = message.get('event_type', 'unknown')
            print(f" CALLBACK: {event}")

        progress_manager.websocket_callback = print_progress_event
        print(" OK: Callback set")

        print("3. Testing progress manager functions...")
        progress_manager.start_training(
            training_id="test-123",
            product_id="P001",
            model_type="mlstm",
            training_mode="product",
            total_epochs=2,
            total_batches=10,
            batch_size=32,
            total_samples=320,
        )
        print(" OK: Training started")
        progress_manager.set_stage("data_preprocessing", 50)
        print(" OK: Stage set")
        progress_manager.start_epoch(0)
        print(" OK: Epoch started")
        progress_manager.finish_training(success=True)
        print(" OK: Training finished")

        print("4. Testing trainer import...")
        from trainers.mlstm_trainer import train_product_model_with_mlstm
        print(" OK: Trainer imported")
        params = list(inspect.signature(train_product_model_with_mlstm).parameters)
        print(f" Parameters: {params}")
        for name in REQUIRED_TRAINER_PARAMS:
            print(f" Has {name}: {name in params}")
        missing = find_missing_params(train_product_model_with_mlstm)
    except Exception as e:
        # Top-level boundary of the script: report any failure with its
        # traceback and turn it into a non-zero exit status.
        print(f"ERROR: {e}")
        traceback.print_exc()
        print("\n=== TEST FAILED ===")
        return 1

    if missing:
        # A missing parameter is a test failure, not just a warning.
        print(f" ERROR: Missing required parameters: {missing}")
        print("\n=== TEST FAILED ===")
        return 1
    print(" OK: Required parameters present")
    print("\n=== TEST COMPLETED SUCCESSFULLY ===")
    return 0


if __name__ == "__main__":
    sys.exit(main())