ShopTRAINING/test/minimal_training_test.py
2025-07-02 11:05:23 +08:00

51 lines
1.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
最小训练测试 - 直接调用训练器
"""
import sys
import os
sys.path.append(os.path.join(os.path.dirname(__file__), 'server'))
def minimal_training_test():
"""直接调用训练器进行测试"""
print("=== 最小训练测试 ===")
try:
# 导入必要的模块
from trainers.mlstm_trainer import train_product_model_with_mlstm
print("✅ 训练器导入成功")
# 创建一个模拟的socketio对象
class MockSocketIO:
def emit(self, event, data, namespace=None):
print(f"[MockSocketIO] {event}: {data}")
mock_socketio = MockSocketIO()
print("🚀 开始最小训练测试2个epoch...")
# 直接调用训练器
result = train_product_model_with_mlstm(
product_id="P001",
epochs=2, # 只训练2个epoch
socketio=mock_socketio,
task_id="test-direct-123"
)
print(f"✅ 训练完成: {type(result)}")
if isinstance(result, tuple) and len(result) >= 2:
model, metrics = result[:2]
print(f"✅ 返回结果: 模型={type(model)}, 指标={metrics}")
else:
print(f"⚠️ 返回结果格式异常: {result}")
except Exception as e:
print(f"❌ 最小训练测试失败: {e}")
import traceback
traceback.print_exc()
if __name__ == "__main__":
print("🧪 开始最小训练测试")
minimal_training_test()
print("🏁 测试完成")