51 lines
1.6 KiB
Python
51 lines
1.6 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
最小训练测试 - 直接调用训练器
|
|||
|
"""
|
|||
|
import sys
|
|||
|
import os
|
|||
|
sys.path.append(os.path.join(os.path.dirname(__file__), 'server'))
|
|||
|
|
|||
|
def minimal_training_test():
|
|||
|
"""直接调用训练器进行测试"""
|
|||
|
print("=== 最小训练测试 ===")
|
|||
|
|
|||
|
try:
|
|||
|
# 导入必要的模块
|
|||
|
from trainers.mlstm_trainer import train_product_model_with_mlstm
|
|||
|
|
|||
|
print("✅ 训练器导入成功")
|
|||
|
|
|||
|
# 创建一个模拟的socketio对象
|
|||
|
class MockSocketIO:
|
|||
|
def emit(self, event, data, namespace=None):
|
|||
|
print(f"[MockSocketIO] {event}: {data}")
|
|||
|
|
|||
|
mock_socketio = MockSocketIO()
|
|||
|
|
|||
|
print("🚀 开始最小训练测试(2个epoch)...")
|
|||
|
|
|||
|
# 直接调用训练器
|
|||
|
result = train_product_model_with_mlstm(
|
|||
|
product_id="P001",
|
|||
|
epochs=2, # 只训练2个epoch
|
|||
|
socketio=mock_socketio,
|
|||
|
task_id="test-direct-123"
|
|||
|
)
|
|||
|
|
|||
|
print(f"✅ 训练完成: {type(result)}")
|
|||
|
if isinstance(result, tuple) and len(result) >= 2:
|
|||
|
model, metrics = result[:2]
|
|||
|
print(f"✅ 返回结果: 模型={type(model)}, 指标={metrics}")
|
|||
|
else:
|
|||
|
print(f"⚠️ 返回结果格式异常: {result}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 最小训练测试失败: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
print("🧪 开始最小训练测试")
|
|||
|
minimal_training_test()
|
|||
|
print("🏁 测试完成")
|