#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Simplified training test - calls the trainer directly to avoid encoding problems.
"""

import os
import sys

# Set up a UTF-8 encoding environment.
# NOTE(review): CPython reads PYTHONIOENCODING and PYTHONLEGACYWINDOWSSTDIO at
# interpreter startup, so assigning them here is expected to influence only
# child processes, not the streams of this process - confirm that is intended.
# NOTE(review): the CPython docs describe PYTHONLEGACYWINDOWSSTDIO as enabled
# by any non-empty value, so '0' may not disable legacy stdio - verify.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'

# Put the "server" directory next to this file at the front of the import path.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'server'))


def simple_training_test():
    """Run a short, direct training call and print a human-readable report.

    Imports ``train_product_model_with_transformer`` from
    ``trainers.transformer_trainer``, calls it with ``product_id='P001'``,
    ``epochs=2`` and ``training_mode='product'``, and prints the outcome.
    A truthy result is unpacked as ``(model, metrics, version)``.

    Any exception raised while importing, training, or unpacking the result
    is caught, reported on stdout together with a traceback, and not
    re-raised.

    Returns:
        bool: ``True`` if the whole sequence ran without raising an
        exception, ``False`` otherwise. Callers that ignore the return
        value see the same printed output as before.
    """
    print("=== Direct Training Test Started ===")

    succeeded = False
    try:
        # Import the trainer lazily so an import failure is reported by this
        # test instead of aborting the script at load time.
        print("Importing transformer trainer...")
        from trainers.transformer_trainer import train_product_model_with_transformer
        print("SUCCESS: Trainer imported")

        # Announce the test configuration.
        print("\nStarting training test...")
        print("Product: P001")
        print("Model: transformer")
        print("Epochs: 2")
        print("Mode: product")
        print("---")

        # Invoke the trainer.
        result = train_product_model_with_transformer(
            product_id='P001',
            epochs=2,
            training_mode='product',
        )

        print("---")
        print("SUCCESS: Training completed")

        if result:
            # A result that is not a 3-item sequence raises here and is
            # reported by the handler below.
            model, metrics, version = result
            print(f"Model type: {type(model)}")
            print(f"Version: {version}")
            if metrics:
                print(f"Metrics: {metrics}")
            else:
                print("WARNING: No metrics returned")
        else:
            print("WARNING: Training returned None")

        succeeded = True

    except Exception as e:
        # Top-level boundary of the test: report everything, never propagate.
        print(f"ERROR: Training failed - {e}")
        import traceback
        traceback.print_exc()

    print("=== Direct Training Test Completed ===")
    return succeeded


if __name__ == "__main__":
    # Run the training test when executed as a script. Exit with status 1
    # only when the test explicitly reports failure by returning False; any
    # other return value (including None) leaves the exit status at 0.
    if simple_training_test() is False:
        sys.exit(1)