73 lines
2.3 KiB
Python
73 lines
2.3 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
简单的训练日志测试
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# 设置编码环境
|
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|
|
|
if os.name == 'nt':
|
|
try:
|
|
os.system('chcp 65001 >nul 2>&1')
|
|
if hasattr(sys.stdout, 'reconfigure'):
|
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
|
except Exception as e:
|
|
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
|
|
|
def test_imports():
|
|
"""测试关键模块导入"""
|
|
print("🧪 测试模块导入...")
|
|
|
|
try:
|
|
# 添加server目录到路径
|
|
server_dir = os.path.join(os.getcwd(), 'server')
|
|
if server_dir not in sys.path:
|
|
sys.path.insert(0, server_dir)
|
|
|
|
# 测试导入
|
|
print("1. 测试导入 training_progress...")
|
|
from utils.training_progress import TrainingProgressManager
|
|
print("✅ 成功导入 TrainingProgressManager")
|
|
|
|
print("2. 测试创建进度管理器...")
|
|
manager = TrainingProgressManager()
|
|
print("✅ 成功创建进度管理器")
|
|
|
|
print("3. 测试基本功能...")
|
|
manager.reset()
|
|
print("✅ reset() 正常")
|
|
|
|
manager.start_training("test", "P001", "transformer", "product", 3, 10, 32, 320)
|
|
print("✅ start_training() 正常")
|
|
|
|
print("4. 测试导入预测器...")
|
|
from core.predictor import PharmacyPredictor
|
|
print("✅ 成功导入 PharmacyPredictor")
|
|
|
|
print("5. 测试导入训练器...")
|
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|
print("✅ 成功导入 transformer 训练器")
|
|
|
|
print("\n🎉 所有模块导入测试通过!")
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"❌ 导入测试失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
success = test_imports()
|
|
|
|
if success:
|
|
print("\n📋 下一步:")
|
|
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
|
print("2. 检查控制台输出是否有训练进度日志")
|
|
else:
|
|
print("\n❌ 测试失败,需要进一步调试") |