71 lines
2.0 KiB
Python
71 lines
2.0 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
直接训练测试 - 绕过API直接调用训练器测试日志输出
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# 设置环境变量确保UTF-8编码
|
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|
if os.name == 'nt':
|
|
os.system('chcp 65001 >nul 2>&1')
|
|
|
|
# 添加server路径
|
|
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'server'))
|
|
|
|
def test_direct_training():
|
|
"""直接测试训练器"""
|
|
|
|
print("🧪 直接训练测试开始")
|
|
print("="*50)
|
|
|
|
try:
|
|
# 导入训练器
|
|
print("📦 导入训练器模块...")
|
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|
|
|
print("✅ 模块导入成功")
|
|
|
|
# 测试训练
|
|
print("\n🚀 开始训练测试...")
|
|
print("📋 产品: P001")
|
|
print("🤖 模型: transformer")
|
|
print("⚙️ 轮次: 2 (快速测试)")
|
|
print("-" * 50)
|
|
|
|
# 调用训练器
|
|
result = train_product_model_with_transformer(
|
|
product_id='P001',
|
|
epochs=2, # 使用很少的轮次快速测试
|
|
training_mode='product'
|
|
)
|
|
|
|
print("-" * 50)
|
|
print("✅ 训练完成!")
|
|
|
|
if result:
|
|
model, metrics, version = result
|
|
print(f"📊 返回结果:")
|
|
print(f" 模型: {type(model)}")
|
|
print(f" 版本: {version}")
|
|
if metrics:
|
|
print(f" 指标: {metrics}")
|
|
else:
|
|
print(" ⚠️ 指标为空")
|
|
else:
|
|
print("⚠️ 训练返回None")
|
|
|
|
except Exception as e:
|
|
print(f"❌ 训练失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
|
|
print("\n" + "="*50)
|
|
print("🎉 直接训练测试完成")
|
|
print("💡 如果看到了训练进度输出,说明日志修复成功")
|
|
print("="*50)
|
|
|
|
if __name__ == "__main__":
|
|
test_direct_training() |