51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
最小训练测试 - 直接调用训练器
|
||
"""
|
||
import sys
|
||
import os
|
||
sys.path.append(os.path.join(os.path.dirname(__file__), 'server'))
|
||
|
||
def minimal_training_test():
|
||
"""直接调用训练器进行测试"""
|
||
print("=== 最小训练测试 ===")
|
||
|
||
try:
|
||
# 导入必要的模块
|
||
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
||
|
||
print("✅ 训练器导入成功")
|
||
|
||
# 创建一个模拟的socketio对象
|
||
class MockSocketIO:
|
||
def emit(self, event, data, namespace=None):
|
||
print(f"[MockSocketIO] {event}: {data}")
|
||
|
||
mock_socketio = MockSocketIO()
|
||
|
||
print("🚀 开始最小训练测试(2个epoch)...")
|
||
|
||
# 直接调用训练器
|
||
result = train_product_model_with_mlstm(
|
||
product_id="P001",
|
||
epochs=2, # 只训练2个epoch
|
||
socketio=mock_socketio,
|
||
task_id="test-direct-123"
|
||
)
|
||
|
||
print(f"✅ 训练完成: {type(result)}")
|
||
if isinstance(result, tuple) and len(result) >= 2:
|
||
model, metrics = result[:2]
|
||
print(f"✅ 返回结果: 模型={type(model)}, 指标={metrics}")
|
||
else:
|
||
print(f"⚠️ 返回结果格式异常: {result}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 最小训练测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
if __name__ == "__main__":
|
||
print("🧪 开始最小训练测试")
|
||
minimal_training_test()
|
||
print("🏁 测试完成") |