#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 测试训练日志修复效果 验证API服务器中的训练进度管理器是否正常工作 """ import os import sys import time import json # 尝试导入requests,如果失败则跳过API测试 try: import requests REQUESTS_AVAILABLE = True except ImportError: REQUESTS_AVAILABLE = False # 设置编码环境 os.environ['PYTHONIOENCODING'] = 'utf-8' os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0' # Windows系统额外配置 if os.name == 'nt': try: os.system('chcp 65001 >nul 2>&1') if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace') else: import io if hasattr(sys.stdout, 'buffer'): sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True) sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True) except Exception as e: print(f"Warning: Failed to set UTF-8 encoding: {e}") def test_training_log_fix(): """测试训练日志修复效果""" print("🧪 开始测试训练日志修复效果") print("=" * 50) # 1. 测试导入修复 print("1️⃣ 测试模块导入...") try: # 添加server目录到路径 server_dir = os.path.join(os.getcwd(), 'server') if server_dir not in sys.path: sys.path.insert(0, server_dir) # 测试导入training_progress from utils.training_progress import TrainingProgressManager print("✅ 成功导入 TrainingProgressManager") # 测试创建实例 manager = TrainingProgressManager() print("✅ 成功创建进度管理器实例") # 测试基本功能 manager.reset() manager.start_training( training_id="test_001", product_id="P001", model_type="transformer", training_mode="product", total_epochs=3, total_batches=10, batch_size=32, total_samples=320 ) print("✅ 进度管理器基本功能正常") manager.start_epoch(1) manager.update_batch(5, 0.1234, 0.001) manager.finish_epoch(0.1234) manager.finish_training(True) print("✅ 进度管理器完整流程测试通过") except Exception as e: print(f"❌ 模块导入测试失败: {e}") import traceback traceback.print_exc() return False # 2. 测试API服务器启动 print("\n2️⃣ 测试API服务器模块加载...") try: # 测试导入api模块中的关键组件 from core.predictor import PharmacyPredictor print("✅ 成功导入 PharmacyPredictor") from trainers.transformer_trainer import train_product_model_with_transformer print("✅ 成功导入 transformer 训练器") # 测试创建预测器实例 predictor = PharmacyPredictor() print("✅ 成功创建预测器实例") except Exception as e: print(f"❌ API服务器模块测试失败: {e}") import traceback traceback.print_exc() return False # 3. 测试API连接(如果服务器正在运行) print("\n3️⃣ 测试API连接...") if REQUESTS_AVAILABLE: try: response = requests.get('http://localhost:5000/api/products', timeout=5) if response.status_code == 200: print("✅ API服务器正在运行并响应正常") # 测试模型列表端点 models_response = requests.get('http://localhost:5000/api/models', timeout=5) if models_response.status_code == 200: print("✅ 模型列表端点正常") else: print(f"⚠️ 模型列表端点状态码: {models_response.status_code}") else: print(f"⚠️ API服务器响应状态码: {response.status_code}") except requests.exceptions.ConnectionError: print("ℹ️ API服务器未运行(这是正常的,需要手动启动)") except Exception as e: print(f"⚠️ API连接测试异常: {e}") else: print("ℹ️ requests库未安装,跳过API连接测试") print("\n" + "=" * 50) print("🎉 训练日志修复测试完成!") print("\n📋 下一步操作建议:") print("1. 启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py") print("2. 在前端进行模型训练测试") print("3. 观察控制台是否有详细的训练进度输出") return True if __name__ == "__main__": test_training_log_fix()