ShopTRAINING/test/simple_training_direct_test.py
2025-07-02 11:05:23 +08:00

65 lines
1.7 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简化训练测试 - 直接调用训练器,避免编码问题
"""
import os
import sys
# 设置UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# 添加server路径
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'server'))
def simple_training_test(product_id='P001', epochs=2, training_mode='product'):
    """Run one short transformer training job and report the outcome.

    The trainer is imported lazily inside the ``try`` block, so a failure to
    import it is reported the same way as a training failure. The script is
    not aborted in that case.

    Args:
        product_id: Product to train on (default ``'P001'``).
        epochs: Number of training epochs (default ``2``, kept small so the
            test finishes quickly).
        training_mode: Training mode forwarded to the trainer
            (default ``'product'``).

    Returns:
        True if the trainer was imported and ran without raising. An empty
        result or empty metrics still counts as success; those cases are only
        printed as warnings. False if any exception was raised. Callers that
        ignored the previous ``None`` return value are unaffected.
    """
    print("=== Direct Training Test Started ===")
    succeeded = False
    try:
        print("Importing transformer trainer...")
        from trainers.transformer_trainer import train_product_model_with_transformer
        print("SUCCESS: Trainer imported")

        print("\nStarting training test...")
        print(f"Product: {product_id}")
        print("Model: transformer")
        print(f"Epochs: {epochs}")
        print(f"Mode: {training_mode}")
        print("---")

        result = train_product_model_with_transformer(
            product_id=product_id,
            epochs=epochs,
            training_mode=training_mode,
        )
        print("---")
        print("SUCCESS: Training completed")

        if result:
            # Assumes the trainer returns (model, metrics, version). TODO:
            # confirm against trainers.transformer_trainer. A different shape
            # raises ValueError here, which is reported below as a failure.
            model, metrics, version = result
            print(f"Model type: {type(model)}")
            print(f"Version: {version}")
            if metrics:
                print(f"Metrics: {metrics}")
            else:
                print("WARNING: No metrics returned")
        else:
            print("WARNING: Training returned None")
        succeeded = True
    except Exception as e:
        # Top-level boundary of a diagnostic script: report everything,
        # including the traceback, then fall through so the closing banner
        # is always printed.
        print(f"ERROR: Training failed - {e}")
        import traceback
        traceback.print_exc()
    print("=== Direct Training Test Completed ===")
    return succeeded


if __name__ == "__main__":
    # Turn the outcome into the process exit status so CI and calling
    # scripts can detect failures. Previously the script always exited 0.
    sys.exit(0 if simple_training_test() else 1)