140 lines
4.9 KiB
Python
140 lines
4.9 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试训练日志修复效果
|
|||
|
验证API服务器中的训练进度管理器是否正常工作
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
import json
|
|||
|
|
|||
|
# 尝试导入requests,如果失败则跳过API测试
|
|||
|
try:
|
|||
|
import requests
|
|||
|
REQUESTS_AVAILABLE = True
|
|||
|
except ImportError:
|
|||
|
REQUESTS_AVAILABLE = False
|
|||
|
|
|||
|
# 设置编码环境
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
|||
|
|
|||
|
# Windows系统额外配置
|
|||
|
if os.name == 'nt':
|
|||
|
try:
|
|||
|
os.system('chcp 65001 >nul 2>&1')
|
|||
|
if hasattr(sys.stdout, 'reconfigure'):
|
|||
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
else:
|
|||
|
import io
|
|||
|
if hasattr(sys.stdout, 'buffer'):
|
|||
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|||
|
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|||
|
except Exception as e:
|
|||
|
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
|||
|
|
|||
|
def test_training_log_fix():
|
|||
|
"""测试训练日志修复效果"""
|
|||
|
print("🧪 开始测试训练日志修复效果")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
# 1. 测试导入修复
|
|||
|
print("1️⃣ 测试模块导入...")
|
|||
|
try:
|
|||
|
# 添加server目录到路径
|
|||
|
server_dir = os.path.join(os.getcwd(), 'server')
|
|||
|
if server_dir not in sys.path:
|
|||
|
sys.path.insert(0, server_dir)
|
|||
|
|
|||
|
# 测试导入training_progress
|
|||
|
from utils.training_progress import TrainingProgressManager
|
|||
|
print("✅ 成功导入 TrainingProgressManager")
|
|||
|
|
|||
|
# 测试创建实例
|
|||
|
manager = TrainingProgressManager()
|
|||
|
print("✅ 成功创建进度管理器实例")
|
|||
|
|
|||
|
# 测试基本功能
|
|||
|
manager.reset()
|
|||
|
manager.start_training(
|
|||
|
training_id="test_001",
|
|||
|
product_id="P001",
|
|||
|
model_type="transformer",
|
|||
|
training_mode="product",
|
|||
|
total_epochs=3,
|
|||
|
total_batches=10,
|
|||
|
batch_size=32,
|
|||
|
total_samples=320
|
|||
|
)
|
|||
|
print("✅ 进度管理器基本功能正常")
|
|||
|
|
|||
|
manager.start_epoch(1)
|
|||
|
manager.update_batch(5, 0.1234, 0.001)
|
|||
|
manager.finish_epoch(0.1234)
|
|||
|
manager.finish_training(True)
|
|||
|
print("✅ 进度管理器完整流程测试通过")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 模块导入测试失败: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return False
|
|||
|
|
|||
|
# 2. 测试API服务器启动
|
|||
|
print("\n2️⃣ 测试API服务器模块加载...")
|
|||
|
try:
|
|||
|
# 测试导入api模块中的关键组件
|
|||
|
from core.predictor import PharmacyPredictor
|
|||
|
print("✅ 成功导入 PharmacyPredictor")
|
|||
|
|
|||
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|||
|
print("✅ 成功导入 transformer 训练器")
|
|||
|
|
|||
|
# 测试创建预测器实例
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
print("✅ 成功创建预测器实例")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ API服务器模块测试失败: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return False
|
|||
|
|
|||
|
# 3. 测试API连接(如果服务器正在运行)
|
|||
|
print("\n3️⃣ 测试API连接...")
|
|||
|
if REQUESTS_AVAILABLE:
|
|||
|
try:
|
|||
|
response = requests.get('http://localhost:5000/api/products', timeout=5)
|
|||
|
if response.status_code == 200:
|
|||
|
print("✅ API服务器正在运行并响应正常")
|
|||
|
|
|||
|
# 测试模型列表端点
|
|||
|
models_response = requests.get('http://localhost:5000/api/models', timeout=5)
|
|||
|
if models_response.status_code == 200:
|
|||
|
print("✅ 模型列表端点正常")
|
|||
|
else:
|
|||
|
print(f"⚠️ 模型列表端点状态码: {models_response.status_code}")
|
|||
|
|
|||
|
else:
|
|||
|
print(f"⚠️ API服务器响应状态码: {response.status_code}")
|
|||
|
except requests.exceptions.ConnectionError:
|
|||
|
print("ℹ️ API服务器未运行(这是正常的,需要手动启动)")
|
|||
|
except Exception as e:
|
|||
|
print(f"⚠️ API连接测试异常: {e}")
|
|||
|
else:
|
|||
|
print("ℹ️ requests库未安装,跳过API连接测试")
|
|||
|
|
|||
|
print("\n" + "=" * 50)
|
|||
|
print("🎉 训练日志修复测试完成!")
|
|||
|
print("\n📋 下一步操作建议:")
|
|||
|
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
|||
|
print("2. 在前端进行模型训练测试")
|
|||
|
print("3. 观察控制台是否有详细的训练进度输出")
|
|||
|
|
|||
|
return True
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_training_log_fix()
|