175 lines
6.2 KiB
Python
175 lines
6.2 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试现代化训练系统
|
|||
|
验证loguru日志和独立训练进程的效果
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import urllib.request
|
|||
|
import json
|
|||
|
import time
|
|||
|
|
|||
|
# Request UTF-8 standard streams through the environment. Python reads these
# variables at interpreter start-up, so they matter for processes that inherit
# this environment; the current process is reconfigured explicitly below.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'

# Extra console configuration on Windows.
if os.name == 'nt':
    try:
        # Switch the console code page to 65001 (UTF-8); command output is discarded.
        os.system('chcp 65001 >nul 2>&1')
        # Check every stream on its own: either one may have been replaced by
        # an object without ``reconfigure`` (or be None), independently of the other.
        for _stream in (sys.stdout, sys.stderr):
            if hasattr(_stream, 'reconfigure'):
                _stream.reconfigure(encoding='utf-8', errors='replace')
    except Exception as e:
        # Best effort only: the script is still usable with the default encoding.
        print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
|||
|
|
|||
|
def send_training_request(training_data=None,
                          url="http://127.0.0.1:5000/api/training",
                          timeout=30):
    """Submit a training job to the API and return ``(status, payload)``.

    Args:
        training_data: JSON-serialisable request body. When omitted, a short
            3-epoch ``transformer`` run for product ``P001`` is requested.
        url: Endpoint that receives the POST request.
        timeout: Socket timeout in seconds handed to ``urlopen``.

    Returns:
        A 2-tuple:

        * ``(status_code, decoded_json)`` when the server answers normally;
        * ``(status_code, body)`` when the server answers with an HTTP error
          status; ``body`` is the decoded JSON error document if it parses,
          otherwise the raw text (or the exception text if the body is empty);
        * ``(None, error_text)`` for every other failure (connection refused,
          timeout, invalid JSON in a successful response, ...).
    """
    from urllib.error import HTTPError

    if training_data is None:
        training_data = {
            "product_id": "P001",
            "model_type": "transformer",
            "epochs": 3,  # short run, meant for testing only
            "training_mode": "product",
        }

    try:
        body = json.dumps(training_data).encode('utf-8')
        request = urllib.request.Request(
            url,
            data=body,
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(request, timeout=timeout) as response:
            return response.getcode(), json.loads(response.read().decode('utf-8'))
    except HTTPError as err:
        # Keep the real status code and the server's explanation instead of
        # collapsing them into ``None`` plus a generic message.
        raw = ''
        try:
            raw = err.read().decode('utf-8', errors='replace')
            detail = json.loads(raw)
        except Exception:
            detail = raw or str(err)
        return err.code, detail
    except Exception as err:
        return None, str(err)
|
|||
|
|
|||
|
def check_task_status(task_id):
    """Fetch the status document of a single training task.

    Args:
        task_id: Identifier of the task. It is converted with ``str()`` and
            percent-encoded, so characters such as spaces, ``/``, ``?`` or
            ``#`` cannot corrupt the request URL.

    Returns:
        The decoded JSON response, or ``{'error': message}`` when the request
        or the decoding fails for any reason.
    """
    from urllib.parse import quote

    try:
        url = f"http://127.0.0.1:5000/api/training/{quote(str(task_id), safe='')}"
        with urllib.request.urlopen(url, timeout=5) as response:
            return json.loads(response.read().decode('utf-8'))
    except Exception as e:
        return {'error': str(e)}
|
|||
|
|
|||
|
def get_all_tasks():
    """Retrieve the training-task listing from the local API.

    Returns:
        The decoded JSON body of ``GET /api/training``; if the request or the
        decoding raises, ``{'error': message}`` is returned instead.
    """
    endpoint = "http://127.0.0.1:5000/api/training"
    try:
        with urllib.request.urlopen(endpoint, timeout=5) as reply:
            raw_body = reply.read()
        text = raw_body.decode('utf-8')
        return json.loads(text)
    except Exception as exc:
        return {'error': str(exc)}
|
|||
|
|
|||
|
def _as_float(value, default=0.0):
    """Coerce *value* to float; return *default* for None or unparsable input."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return default


def _api_is_reachable(url="http://127.0.0.1:5000/api/products", timeout=3):
    """Return True when a GET on *url* is answered with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.getcode() == 200
    except Exception:
        # Refused connection, timeout, HTTP error status: all mean "not usable".
        # Unlike a bare ``except:``, Ctrl-C and SystemExit still propagate.
        return False


def _monitor_task(task_id, checks=15, interval=2):
    """Poll the task up to *checks* times, printing one line per poll.

    Waits *interval* seconds before each poll and stops early once the task
    reports ``completed`` or ``failed``.
    """
    for attempt in range(1, checks + 1):
        time.sleep(interval)

        task_status = check_task_status(task_id)
        if task_status.get('status') != 'success':
            print(f"⚠️ 状态查询失败: {task_status}")
            continue

        # ``or {}`` also covers a payload whose "data" field is JSON null.
        data = task_status.get('data') or {}
        status_text = data.get('status', 'unknown')
        progress = _as_float(data.get('progress', 0))
        message = data.get('message', '')
        process_id = data.get('process_id', 'N/A')

        print(f"📊 进度检查 {attempt}/{checks}: 状态={status_text}, 进度={progress:.1f}%, "
              f"进程ID={process_id}, 消息={message}")

        if status_text in ('completed', 'failed'):
            break


def _report_final_status(task_id):
    """Query the task one last time and print its final state, metrics and error."""
    final_status = check_task_status(task_id)
    if final_status.get('status') != 'success':
        return

    data = final_status.get('data') or {}
    print(f"🏁 最终状态: {data.get('status')}")
    print(f"📊 最终进度: {_as_float(data.get('progress', 0)):.1f}%")
    if data.get('metrics'):
        print(f"📈 训练指标: {data.get('metrics')}")
    if data.get('error'):
        print(f"❌ 错误信息: {data.get('error')}")


def _report_task_list(limit=5):
    """Print a summary of the first *limit* tasks returned by ``get_all_tasks``."""
    all_tasks = get_all_tasks()
    if all_tasks.get('status') != 'success':
        print(f"❌ 获取任务列表失败: {all_tasks}")
        return

    tasks = all_tasks.get('data') or []
    print(f"📋 发现 {len(tasks)} 个训练任务:")

    for index, task in enumerate(tasks[:limit], start=1):
        # ``str()`` keeps the slice safe when the id is numeric or null.
        print(f" {index}. ID: {str(task.get('task_id', 'N/A'))[:8]}...")
        print(f" 状态: {task.get('status', 'N/A')}")
        print(f" 模型: {task.get('model_type', 'N/A')}")
        print(f" 产品: {task.get('product_id', 'N/A')}")
        print(f" 进程ID: {task.get('process_id', 'N/A')}")


def main():
    """Run the manual end-to-end check of the training API.

    Steps:
        1. Print a banner and verify that the API answers on
           ``/api/products``; return early with a hint if it does not.
        2. Submit a training request, then poll its status (15 polls, 2 s
           apart) and print the final state.
        3. Print a summary of the known training tasks.
        4. Print a checklist of things to verify by eye.

    Returns:
        None. All results are written to standard output.
    """
    print("🚀 现代化训练系统测试")
    print("=" * 60)
    print("📋 新系统特性:")
    print(" ✅ loguru现代化日志系统")
    print(" ✅ 独立训练进程(避免GIL限制)")
    print(" ✅ 完美中文和emoji支持")
    print(" ✅ 多线程安全日志输出")
    print(" ✅ 实时进度WebSocket推送")
    print("=" * 60)

    # Check the API connection before doing anything else.
    if not _api_is_reachable():
        print("❌ API服务器未运行,请先启动新版API服务器")
        print(" 启动命令: uv run server/api.py")
        return

    print("✅ API服务器连接正常")
    print()

    print("🎯 测试1: 提交训练任务到新的进程管理器")
    print("-" * 50)

    status, response = send_training_request()

    if status == 200:
        task_id = response.get('task_id')
        message = response.get('message')
        # A null or missing message must not break the substring test.
        uses_process_manager = '独立进程' in str(message or '')

        print("✅ 训练任务提交成功")
        print(f"🆔 任务ID: {task_id}")
        print(f"💬 响应消息: {message}")
        print(f"🔧 使用进程管理器: {'是' if uses_process_manager else '否'}")
        print()

        print("👀 观察要点:")
        print("1. API服务器控制台应显示彩色的loguru日志")
        print("2. 训练任务在独立进程中运行")
        print("3. 日志格式: [时间] | 级别 | 线程ID | 消息")
        print("4. 中文和emoji应完美显示")
        print()

        print("⏳ 监控训练进度...")
        _monitor_task(task_id)

        print()
        print("📈 最终状态检查...")
        _report_final_status(task_id)
    else:
        print(f"❌ 训练请求失败: {response}")

    print()
    print("🎯 测试2: 检查所有任务列表")
    print("-" * 50)

    _report_task_list()

    print()
    print("=" * 60)
    print("🎯 测试总结:")
    print("1. 检查API服务器控制台的日志输出格式")
    print("2. 验证训练任务是否在独立进程中运行")
    print("3. 确认中文和emoji显示是否正常")
    print("4. 观察是否还有重复日志输出的问题")
    print("5. 测试进程间通信和状态同步")
    print("=" * 60)


if __name__ == "__main__":
    main()
|