ShopTRAINING/test/test_training_log_fix.py
2025-07-02 11:05:23 +08:00

140 lines
4.9 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试训练日志修复效果
验证API服务器中的训练进度管理器是否正常工作
"""
import os
import sys
import time
import json
# 尝试导入requests如果失败则跳过API测试
try:
import requests
REQUESTS_AVAILABLE = True
except ImportError:
REQUESTS_AVAILABLE = False
# 设置编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# Windows系统额外配置
if os.name == 'nt':
try:
os.system('chcp 65001 >nul 2>&1')
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
else:
import io
if hasattr(sys.stdout, 'buffer'):
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
except Exception as e:
print(f"Warning: Failed to set UTF-8 encoding: {e}")
def test_training_log_fix():
"""测试训练日志修复效果"""
print("🧪 开始测试训练日志修复效果")
print("=" * 50)
# 1. 测试导入修复
print("1⃣ 测试模块导入...")
try:
# 添加server目录到路径
server_dir = os.path.join(os.getcwd(), 'server')
if server_dir not in sys.path:
sys.path.insert(0, server_dir)
# 测试导入training_progress
from utils.training_progress import TrainingProgressManager
print("✅ 成功导入 TrainingProgressManager")
# 测试创建实例
manager = TrainingProgressManager()
print("✅ 成功创建进度管理器实例")
# 测试基本功能
manager.reset()
manager.start_training(
training_id="test_001",
product_id="P001",
model_type="transformer",
training_mode="product",
total_epochs=3,
total_batches=10,
batch_size=32,
total_samples=320
)
print("✅ 进度管理器基本功能正常")
manager.start_epoch(1)
manager.update_batch(5, 0.1234, 0.001)
manager.finish_epoch(0.1234)
manager.finish_training(True)
print("✅ 进度管理器完整流程测试通过")
except Exception as e:
print(f"❌ 模块导入测试失败: {e}")
import traceback
traceback.print_exc()
return False
# 2. 测试API服务器启动
print("\n2⃣ 测试API服务器模块加载...")
try:
# 测试导入api模块中的关键组件
from core.predictor import PharmacyPredictor
print("✅ 成功导入 PharmacyPredictor")
from trainers.transformer_trainer import train_product_model_with_transformer
print("✅ 成功导入 transformer 训练器")
# 测试创建预测器实例
predictor = PharmacyPredictor()
print("✅ 成功创建预测器实例")
except Exception as e:
print(f"❌ API服务器模块测试失败: {e}")
import traceback
traceback.print_exc()
return False
# 3. 测试API连接如果服务器正在运行
print("\n3⃣ 测试API连接...")
if REQUESTS_AVAILABLE:
try:
response = requests.get('http://localhost:5000/api/products', timeout=5)
if response.status_code == 200:
print("✅ API服务器正在运行并响应正常")
# 测试模型列表端点
models_response = requests.get('http://localhost:5000/api/models', timeout=5)
if models_response.status_code == 200:
print("✅ 模型列表端点正常")
else:
print(f"⚠️ 模型列表端点状态码: {models_response.status_code}")
else:
print(f"⚠️ API服务器响应状态码: {response.status_code}")
except requests.exceptions.ConnectionError:
print(" API服务器未运行这是正常的需要手动启动")
except Exception as e:
print(f"⚠️ API连接测试异常: {e}")
else:
print(" requests库未安装跳过API连接测试")
print("\n" + "=" * 50)
print("🎉 训练日志修复测试完成!")
print("\n📋 下一步操作建议:")
print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
print("2. 在前端进行模型训练测试")
print("3. 观察控制台是否有详细的训练进度输出")
return True
if __name__ == "__main__":
test_training_log_fix()