ShopTRAINING/test/direct_training_test.py
2025-07-02 11:05:23 +08:00

71 lines
2.0 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
直接训练测试 - 绕过API直接调用训练器测试日志输出
"""
import os
import sys
# 设置环境变量确保UTF-8编码
os.environ['PYTHONIOENCODING'] = 'utf-8'
if os.name == 'nt':
os.system('chcp 65001 >nul 2>&1')
# 添加server路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'server'))
def test_direct_training():
"""直接测试训练器"""
print("🧪 直接训练测试开始")
print("="*50)
try:
# 导入训练器
print("📦 导入训练器模块...")
from trainers.transformer_trainer import train_product_model_with_transformer
print("✅ 模块导入成功")
# 测试训练
print("\n🚀 开始训练测试...")
print("📋 产品: P001")
print("🤖 模型: transformer")
print("⚙️ 轮次: 2 (快速测试)")
print("-" * 50)
# 调用训练器
result = train_product_model_with_transformer(
product_id='P001',
epochs=2, # 使用很少的轮次快速测试
training_mode='product'
)
print("-" * 50)
print("✅ 训练完成!")
if result:
model, metrics, version = result
print(f"📊 返回结果:")
print(f" 模型: {type(model)}")
print(f" 版本: {version}")
if metrics:
print(f" 指标: {metrics}")
else:
print(" ⚠️ 指标为空")
else:
print("⚠️ 训练返回None")
except Exception as e:
print(f"❌ 训练失败: {e}")
import traceback
traceback.print_exc()
print("\n" + "="*50)
print("🎉 直接训练测试完成")
print("💡 如果看到了训练进度输出,说明日志修复成功")
print("="*50)
if __name__ == "__main__":
test_direct_training()