140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试训练日志修复效果
|
||
验证API服务器中的训练进度管理器是否正常工作
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import json
|
||
|
||
# 尝试导入requests,如果失败则跳过API测试
|
||
try:
|
||
import requests
|
||
REQUESTS_AVAILABLE = True
|
||
except ImportError:
|
||
REQUESTS_AVAILABLE = False
|
||
|
||
# 设置编码环境
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
||
|
||
# Windows系统额外配置
|
||
if os.name == 'nt':
|
||
try:
|
||
os.system('chcp 65001 >nul 2>&1')
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
else:
|
||
import io
|
||
if hasattr(sys.stdout, 'buffer'):
|
||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
||
except Exception as e:
|
||
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
||
|
||
def test_training_log_fix():
|
||
"""测试训练日志修复效果"""
|
||
print("🧪 开始测试训练日志修复效果")
|
||
print("=" * 50)
|
||
|
||
# 1. 测试导入修复
|
||
print("1️⃣ 测试模块导入...")
|
||
try:
|
||
# 添加server目录到路径
|
||
server_dir = os.path.join(os.getcwd(), 'server')
|
||
if server_dir not in sys.path:
|
||
sys.path.insert(0, server_dir)
|
||
|
||
# 测试导入training_progress
|
||
from utils.training_progress import TrainingProgressManager
|
||
print("✅ 成功导入 TrainingProgressManager")
|
||
|
||
# 测试创建实例
|
||
manager = TrainingProgressManager()
|
||
print("✅ 成功创建进度管理器实例")
|
||
|
||
# 测试基本功能
|
||
manager.reset()
|
||
manager.start_training(
|
||
training_id="test_001",
|
||
product_id="P001",
|
||
model_type="transformer",
|
||
training_mode="product",
|
||
total_epochs=3,
|
||
total_batches=10,
|
||
batch_size=32,
|
||
total_samples=320
|
||
)
|
||
print("✅ 进度管理器基本功能正常")
|
||
|
||
manager.start_epoch(1)
|
||
manager.update_batch(5, 0.1234, 0.001)
|
||
manager.finish_epoch(0.1234)
|
||
manager.finish_training(True)
|
||
print("✅ 进度管理器完整流程测试通过")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 模块导入测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
# 2. 测试API服务器启动
|
||
print("\n2️⃣ 测试API服务器模块加载...")
|
||
try:
|
||
# 测试导入api模块中的关键组件
|
||
from core.predictor import PharmacyPredictor
|
||
print("✅ 成功导入 PharmacyPredictor")
|
||
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
print("✅ 成功导入 transformer 训练器")
|
||
|
||
# 测试创建预测器实例
|
||
predictor = PharmacyPredictor()
|
||
print("✅ 成功创建预测器实例")
|
||
|
||
except Exception as e:
|
||
print(f"❌ API服务器模块测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
# 3. 测试API连接(如果服务器正在运行)
|
||
print("\n3️⃣ 测试API连接...")
|
||
if REQUESTS_AVAILABLE:
|
||
try:
|
||
response = requests.get('http://localhost:5000/api/products', timeout=5)
|
||
if response.status_code == 200:
|
||
print("✅ API服务器正在运行并响应正常")
|
||
|
||
# 测试模型列表端点
|
||
models_response = requests.get('http://localhost:5000/api/models', timeout=5)
|
||
if models_response.status_code == 200:
|
||
print("✅ 模型列表端点正常")
|
||
else:
|
||
print(f"⚠️ 模型列表端点状态码: {models_response.status_code}")
|
||
|
||
else:
|
||
print(f"⚠️ API服务器响应状态码: {response.status_code}")
|
||
except requests.exceptions.ConnectionError:
|
||
print("ℹ️ API服务器未运行(这是正常的,需要手动启动)")
|
||
except Exception as e:
|
||
print(f"⚠️ API连接测试异常: {e}")
|
||
else:
|
||
print("ℹ️ requests库未安装,跳过API连接测试")
|
||
|
||
print("\n" + "=" * 50)
|
||
print("🎉 训练日志修复测试完成!")
|
||
print("\n📋 下一步操作建议:")
|
||
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
||
print("2. 在前端进行模型训练测试")
|
||
print("3. 观察控制台是否有详细的训练进度输出")
|
||
|
||
return True
|
||
|
||
if __name__ == "__main__":
|
||
test_training_log_fix() |