#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Smoke test for the modernized training system.

Exercises the training HTTP API to check the loguru-based logging and the
standalone training processes.
"""

import json
import os
import sys
import time
import urllib.request

# Request UTF-8 standard I/O through the environment.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'

# Extra console configuration on Windows.
if os.name == 'nt':
    try:
        # Switch the console code page to 65001 (UTF-8); output is discarded.
        os.system('chcp 65001 >nul 2>&1')
        # Either stream may have been replaced by an object that lacks
        # reconfigure(), so each one is checked on its own.
        for _stream in (sys.stdout, sys.stderr):
            if hasattr(_stream, 'reconfigure'):
                _stream.reconfigure(encoding='utf-8', errors='replace')
    except Exception as e:
        # Best effort only: the test can still run with a mis-encoded console.
        print(f"Warning: Failed to set UTF-8 encoding: {e}")


def send_training_request(base_url="http://127.0.0.1:5000", timeout=30):
    """Submit a short transformer training job for product ``P001``.

    Args:
        base_url: Root URL of the API server; the job is POSTed as JSON to
            ``<base_url>/api/training``.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        ``(status_code, decoded_json_body)`` when the request completes, or
        ``(None, error_message)`` if anything raises (invalid URL, connection
        failure, HTTP error status, body that is not valid JSON, ...).
    """
    payload = {
        "product_id": "P001",
        "model_type": "transformer",
        "epochs": 3,  # short run, used for testing only
        "training_mode": "product",
    }

    try:
        request = urllib.request.Request(
            f"{base_url}/api/training",
            data=json.dumps(payload).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(request, timeout=timeout) as response:
            status_code = response.getcode()
            return status_code, json.loads(response.read().decode('utf-8'))
    except Exception as e:
        # Deliberately broad: the caller only reports that submission failed.
        return None, str(e)


def check_task_status(task_id, base_url="http://127.0.0.1:5000", timeout=5):
    """Fetch the server's record for a single training task.

    Args:
        task_id: Identifier appended to ``<base_url>/api/training/``.
        base_url: Root URL of the API server.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        The decoded JSON response, or ``{'error': message}`` if the request
        or the decoding raises.
    """
    try:
        url = f"{base_url}/api/training/{task_id}"
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return json.loads(response.read().decode('utf-8'))
    except Exception as e:
        # Deliberately broad: failures are reported as data, not raised.
        return {'error': str(e)}


def get_all_tasks(base_url="http://127.0.0.1:5000", timeout=5):
    """Fetch the list of training tasks known to the server.

    Args:
        base_url: Root URL of the API server; ``<base_url>/api/training`` is
            requested with GET.
        timeout: Timeout in seconds passed to ``urlopen``.

    Returns:
        The decoded JSON response, or ``{'error': message}`` if the request
        or the decoding raises.
    """
    try:
        with urllib.request.urlopen(f"{base_url}/api/training", timeout=timeout) as response:
            return json.loads(response.read().decode('utf-8'))
    except Exception as e:
        # Deliberately broad: failures are reported as data, not raised.
        return {'error': str(e)}


def _api_is_reachable(url="http://127.0.0.1:5000/api/products", timeout=3):
    """Return True when a GET to *url* answers with HTTP status 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.getcode() == 200
    except Exception:
        # Not a bare ``except``: Ctrl-C and SystemExit must still propagate.
        return False


def _monitor_task(task_id, checks=15, interval=2):
    """Poll *task_id* until it is completed/failed or *checks* polls ran."""
    for attempt in range(1, checks + 1):
        time.sleep(interval)

        task_status = check_task_status(task_id)
        if task_status.get('status') != 'success':
            print(f"⚠️ 状态查询失败: {task_status}")
            continue

        # ``or`` fallbacks also cover keys that are present but null.
        data = task_status.get('data') or {}
        status_text = data.get('status', 'unknown')
        progress = data.get('progress') or 0
        message = data.get('message', '')
        process_id = data.get('process_id', 'N/A')

        print(f"📊 进度检查 {attempt}/{checks}: 状态={status_text}, 进度={progress:.1f}%, "
              f"进程ID={process_id}, 消息={message}")

        if status_text in ('completed', 'failed'):
            break


def _report_final_status(task_id):
    """Query *task_id* once more and print its final state, if available."""
    final_status = check_task_status(task_id)
    if final_status.get('status') != 'success':
        return

    data = final_status.get('data') or {}
    print(f"🏁 最终状态: {data.get('status')}")
    print(f"📊 最终进度: {data.get('progress') or 0:.1f}%")
    if data.get('metrics'):
        print(f"📈 训练指标: {data.get('metrics')}")
    if data.get('error'):
        print(f"❌ 错误信息: {data.get('error')}")


def _report_task_list(limit=5):
    """Print a summary of the first *limit* tasks returned by the server."""
    all_tasks = get_all_tasks()
    if all_tasks.get('status') != 'success':
        print(f"❌ 获取任务列表失败: {all_tasks}")
        return

    tasks = all_tasks.get('data') or []
    print(f"📋 发现 {len(tasks)} 个训练任务:")

    for index, task in enumerate(tasks[:limit], start=1):
        # str(...) keeps the slice safe when the id is missing or not a string.
        short_id = str(task.get('task_id') or 'N/A')[:8]
        print(f" {index}. ID: {short_id}...")
        print(f" 状态: {task.get('status', 'N/A')}")
        print(f" 模型: {task.get('model_type', 'N/A')}")
        print(f" 产品: {task.get('product_id', 'N/A')}")
        print(f" 进程ID: {task.get('process_id', 'N/A')}")


def main():
    """Run the manual smoke test against the local training API.

    Steps:
      1. Print a banner and verify the API answers on ``/api/products``;
         return early with a hint if it does not.
      2. Submit one training job, then poll its status (up to 15 polls,
         2 seconds apart) and print the final state.
      3. List the first five tasks known to the server.
      4. Print a checklist of things to verify by eye in the server console.

    All results are printed; nothing is returned.
    """
    print("🚀 现代化训练系统测试")
    print("=" * 60)
    print("📋 新系统特性:")
    print(" ✅ loguru现代化日志系统")
    print(" ✅ 独立训练进程(避免GIL限制)")
    print(" ✅ 完美中文和emoji支持")
    print(" ✅ 多线程安全日志输出")
    print(" ✅ 实时进度WebSocket推送")
    print("=" * 60)

    # Check the API connection before doing anything else.
    if not _api_is_reachable():
        print("❌ API服务器未运行,请先启动新版API服务器")
        print(" 启动命令: uv run server/api.py")
        return

    print("✅ API服务器连接正常")
    print()

    print("🎯 测试1: 提交训练任务到新的进程管理器")
    print("-" * 50)

    status, response = send_training_request()

    if status == 200:
        task_id = response.get('task_id')
        message = response.get('message')
        print("✅ 训练任务提交成功")
        print(f"🆔 任务ID: {task_id}")
        print(f"💬 响应消息: {message}")
        # ``message or ''`` keeps the substring test safe for a null message.
        print(f"🔧 使用进程管理器: {'是' if '独立进程' in (message or '') else '否'}")
        print()

        print("👀 观察要点:")
        print("1. API服务器控制台应显示彩色的loguru日志")
        print("2. 训练任务在独立进程中运行")
        print("3. 日志格式: [时间] | 级别 | 线程ID | 消息")
        print("4. 中文和emoji应完美显示")
        print()

        print("⏳ 监控训练进度...")
        _monitor_task(task_id)

        print()
        print("📈 最终状态检查...")
        _report_final_status(task_id)
    else:
        print(f"❌ 训练请求失败: {response}")

    print()
    print("🎯 测试2: 检查所有任务列表")
    print("-" * 50)

    _report_task_list()

    print()
    print("=" * 60)
    print("🎯 测试总结:")
    print("1. 检查API服务器控制台的日志输出格式")
    print("2. 验证训练任务是否在独立进程中运行")
    print("3. 确认中文和emoji显示是否正常")
    print("4. 观察是否还有重复日志输出的问题")
    print("5. 测试进程间通信和状态同步")
    print("=" * 60)


# Run the smoke test only when executed as a script, not when imported.
if __name__ == "__main__":
    main()