ShopTRAINING/test/simple_training_log_test.py
2025-07-02 11:05:23 +08:00

73 lines
2.3 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简单的训练日志测试
"""
import os
import sys
# 设置编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
if os.name == 'nt':
try:
os.system('chcp 65001 >nul 2>&1')
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"Warning: Failed to set UTF-8 encoding: {e}")
def test_imports():
"""测试关键模块导入"""
print("🧪 测试模块导入...")
try:
# 添加server目录到路径
server_dir = os.path.join(os.getcwd(), 'server')
if server_dir not in sys.path:
sys.path.insert(0, server_dir)
# 测试导入
print("1. 测试导入 training_progress...")
from utils.training_progress import TrainingProgressManager
print("✅ 成功导入 TrainingProgressManager")
print("2. 测试创建进度管理器...")
manager = TrainingProgressManager()
print("✅ 成功创建进度管理器")
print("3. 测试基本功能...")
manager.reset()
print("✅ reset() 正常")
manager.start_training("test", "P001", "transformer", "product", 3, 10, 32, 320)
print("✅ start_training() 正常")
print("4. 测试导入预测器...")
from core.predictor import PharmacyPredictor
print("✅ 成功导入 PharmacyPredictor")
print("5. 测试导入训练器...")
from trainers.transformer_trainer import train_product_model_with_transformer
print("✅ 成功导入 transformer 训练器")
print("\n🎉 所有模块导入测试通过!")
return True
except Exception as e:
print(f"❌ 导入测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
success = test_imports()
if success:
print("\n📋 下一步:")
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
print("2. 检查控制台输出是否有训练进度日志")
else:
print("\n❌ 测试失败,需要进一步调试")