ShopTRAINING/test/test_training_log_fix.py

140 lines
4.9 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试训练日志修复效果
验证API服务器中的训练进度管理器是否正常工作
"""
import os
import sys
import time
import json
# 尝试导入requests如果失败则跳过API测试
try:
import requests
REQUESTS_AVAILABLE = True
except ImportError:
REQUESTS_AVAILABLE = False
# 设置编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# Windows系统额外配置
if os.name == 'nt':
try:
os.system('chcp 65001 >nul 2>&1')
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
else:
import io
if hasattr(sys.stdout, 'buffer'):
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
except Exception as e:
print(f"Warning: Failed to set UTF-8 encoding: {e}")
def test_training_log_fix():
"""测试训练日志修复效果"""
print("🧪 开始测试训练日志修复效果")
print("=" * 50)
# 1. 测试导入修复
print("1⃣ 测试模块导入...")
try:
# 添加server目录到路径
server_dir = os.path.join(os.getcwd(), 'server')
if server_dir not in sys.path:
sys.path.insert(0, server_dir)
# 测试导入training_progress
from utils.training_progress import TrainingProgressManager
print("✅ 成功导入 TrainingProgressManager")
# 测试创建实例
manager = TrainingProgressManager()
print("✅ 成功创建进度管理器实例")
# 测试基本功能
manager.reset()
manager.start_training(
training_id="test_001",
product_id="P001",
model_type="transformer",
training_mode="product",
total_epochs=3,
total_batches=10,
batch_size=32,
total_samples=320
)
print("✅ 进度管理器基本功能正常")
manager.start_epoch(1)
manager.update_batch(5, 0.1234, 0.001)
manager.finish_epoch(0.1234)
manager.finish_training(True)
print("✅ 进度管理器完整流程测试通过")
except Exception as e:
print(f"❌ 模块导入测试失败: {e}")
import traceback
traceback.print_exc()
return False
# 2. 测试API服务器启动
print("\n2⃣ 测试API服务器模块加载...")
try:
# 测试导入api模块中的关键组件
from core.predictor import PharmacyPredictor
print("✅ 成功导入 PharmacyPredictor")
from trainers.transformer_trainer import train_product_model_with_transformer
print("✅ 成功导入 transformer 训练器")
# 测试创建预测器实例
predictor = PharmacyPredictor()
print("✅ 成功创建预测器实例")
except Exception as e:
print(f"❌ API服务器模块测试失败: {e}")
import traceback
traceback.print_exc()
return False
# 3. 测试API连接如果服务器正在运行
print("\n3⃣ 测试API连接...")
if REQUESTS_AVAILABLE:
try:
response = requests.get('http://localhost:5000/api/products', timeout=5)
if response.status_code == 200:
print("✅ API服务器正在运行并响应正常")
# 测试模型列表端点
models_response = requests.get('http://localhost:5000/api/models', timeout=5)
if models_response.status_code == 200:
print("✅ 模型列表端点正常")
else:
print(f"⚠️ 模型列表端点状态码: {models_response.status_code}")
else:
print(f"⚠️ API服务器响应状态码: {response.status_code}")
except requests.exceptions.ConnectionError:
print(" API服务器未运行这是正常的需要手动启动")
except Exception as e:
print(f"⚠️ API连接测试异常: {e}")
else:
print(" requests库未安装跳过API连接测试")
print("\n" + "=" * 50)
print("🎉 训练日志修复测试完成!")
print("\n📋 下一步操作建议:")
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
print("2. 在前端进行模型训练测试")
print("3. 观察控制台是否有详细的训练进度输出")
return True
if __name__ == "__main__":
test_training_log_fix()