157 lines
5.5 KiB
Python
157 lines
5.5 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
并发训练测试 - 测试多个模型同时训练的日志输出
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import urllib.request
|
||
import urllib.parse
|
||
import json
|
||
import time
|
||
import threading
|
||
|
||
# Request UTF-8 for the standard streams.
# NOTE(review): per the CPython docs these variables are read at interpreter
# start-up, so assigning them here presumably only affects child processes
# launched from this script - confirm if the current process matters.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'

# Extra configuration on Windows
if os.name == 'nt':
    try:
        # Switch the console code page to 65001 (UTF-8); command output is
        # discarded by the redirections.
        os.system('chcp 65001 >nul 2>&1')
        # Check each stream on its own: one of them may have been replaced by
        # an object that does not offer reconfigure().
        for _stream in (sys.stdout, sys.stderr):
            if hasattr(_stream, 'reconfigure'):
                _stream.reconfigure(encoding='utf-8', errors='replace')
    except Exception as e:
        # Best effort only: the script can still run with the default encoding.
        print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
||
|
||
def send_training_request(model_type, product_id="P001", epochs=1,
                          api_url="http://127.0.0.1:5000/api/training"):
    """Submit a training job to the API as a JSON POST request.

    Args:
        model_type: Value sent as ``model_type`` in the request body.
        product_id: Value sent as ``product_id`` in the request body.
        epochs: Value sent as ``epochs`` in the request body.
        api_url: Endpoint that receives the POST; defaults to the local
            server address this script has always used.

    Returns:
        A ``(status, payload)`` tuple:

        * the server answered with a success status: the HTTP status code
          and the JSON-decoded response body;
        * the server answered with an HTTP error status: that status code
          and the JSON-decoded error body, or the raw body text (falling
          back to the error description) when the body is not JSON;
        * anything else failed (connection error, timeout, undecodable
          success body, ...): ``None`` and the error message as a string.
    """
    import urllib.error  # local import: makes the dependency on HTTPError explicit

    training_data = {
        "product_id": product_id,
        "model_type": model_type,
        "epochs": epochs,
        "training_mode": "product",
    }

    try:
        json_data = json.dumps(training_data).encode('utf-8')
        req = urllib.request.Request(
            api_url,
            data=json_data,
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(req, timeout=30) as response:
            return response.getcode(), json.loads(response.read().decode('utf-8'))
    except urllib.error.HTTPError as e:
        # urlopen raises for non-2xx answers; keep the status code and the
        # server's own error body instead of collapsing both into a string.
        try:
            detail = e.read().decode('utf-8', errors='replace')
        except Exception:
            detail = ''
        finally:
            e.close()
        try:
            return e.code, json.loads(detail)
        except ValueError:
            return e.code, detail or str(e)
    except Exception as e:
        return None, str(e)
|
||
|
||
def test_single_training(model_type, test_name):
    """Start one training task and report the outcome on stdout.

    Args:
        model_type: Model type forwarded to ``send_training_request``.
        test_name: Label prefixed to every message printed for this task.

    Returns:
        The ``task_id`` taken from the response when the request returned
        HTTP 200 and the response carries a non-empty ``task_id``;
        otherwise ``None``.
    """
    print(f"🚀 {test_name}: 启动 {model_type} 训练...")

    status, response = send_training_request(model_type)

    # The payload is only useful when it is a JSON object that carries a
    # task id; a 200 answer without one is reported as a failed start rather
    # than crashing on the slice below.
    task_id = response.get('task_id') if isinstance(response, dict) else None

    if status != 200 or not task_id:
        print(f"❌ {test_name}: {model_type} 训练启动失败: {response}")
        return None

    # str() keeps the short-id display working for non-string ids.
    print(f"✅ {test_name}: {model_type} 训练启动成功, 任务ID: {str(task_id)[:8]}")
    return task_id
|
||
|
||
def _api_is_reachable(url="http://127.0.0.1:5000/api/products", timeout=3):
    """Return True when a GET request to ``url`` is answered with HTTP 200."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.getcode() == 200
    except Exception:
        # Any failure (refused connection, timeout, HTTP error status) means
        # the API cannot be used. Unlike a bare ``except``, Ctrl+C still
        # propagates.
        return False


def _print_final_status(tasks):
    """Fetch the task list from the API and print the status of ``tasks``.

    Args:
        tasks: ``(model, task_id)`` pairs for the tasks started by ``main``.
    """
    # Built once instead of once per element of the server's task list.
    started_ids = [task_id for _, task_id in tasks]

    try:
        with urllib.request.urlopen("http://127.0.0.1:5000/api/training", timeout=5) as response:
            result = json.loads(response.read().decode('utf-8'))

        recent_tasks = [t for t in result.get('data', []) if t.get('task_id') in started_ids]

        print("📊 并发任务完成状态:")
        for task in recent_tasks:
            status = task.get('status', 'unknown')
            model_type = task.get('model_type', 'unknown')
            task_id = task.get('task_id', 'unknown')[:8]
            print(f" {model_type}: {task_id} - {status}")
    except Exception as e:
        print(f"❌ 检查状态失败: {e}")


def main():
    """Run the manual concurrent-training log test.

    Checks that the API answers, starts one training task, then starts
    three more with a short delay between the requests, waits a fixed time
    and finally prints the status the API reports for the three later tasks.
    The log output itself has to be inspected by hand on the API server
    console; this function only prints guidance for that.
    """
    print("🔧 并发训练日志测试")
    print("=" * 60)
    print("🎯 目标: 测试多个模型同时训练时的日志输出")
    print("📋 期望: 不同训练任务的日志能清晰区分")
    print("=" * 60)

    # Check the API connection
    if not _api_is_reachable():
        print("❌ API服务器未运行,请先启动: PYTHONIOENCODING=utf-8 uv run server/api.py")
        return

    print("✅ API服务器连接正常\n")

    # Test 1: a single training task
    print("🧪 测试1: 单个训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台,应该看到:")
    print(" [时间戳][线程ID][任务ID][标签] 消息内容")
    print(" 例如: [14:30:45][线程12345][826b7ef4][ENTRY] 🔥 训练任务线程启动")

    task1 = test_single_training("transformer", "单任务测试")

    if task1:
        print("⏳ 等待5秒观察单任务日志...")
        time.sleep(5)

    # Test 2: concurrent training tasks
    print("\n🧪 测试2: 并发训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台,应该看到:")
    print(" 多个不同线程ID和任务ID的交替输出")
    print(" 通过[线程ID]和[任务ID]可以区分不同的训练任务")

    # Start three different training tasks. The requests are sent one after
    # another from this thread, 0.5 s apart.
    models = ["mlstm", "tcn", "kan"]
    tasks = []

    print("🚀 同时启动3个训练任务...")
    for i, model in enumerate(models):
        task_id = test_single_training(model, f"并发任务{i+1}")
        if task_id:
            tasks.append((model, task_id))
        time.sleep(0.5)  # stagger the start times slightly

    print(f"\n📊 启动了 {len(tasks)} 个并发训练任务:")
    for model, task_id in tasks:
        print(f" {model}: {task_id[:8]}")

    print("\n⏳ 等待20秒观察并发训练日志...")
    print("💡 重点观察:")
    print(" 1. 不同任务的日志是否能通过线程ID和任务ID区分")
    print(" 2. 日志输出是否完整,没有丢失")
    print(" 3. 多个训练器是否能同时输出进度信息")

    # Fixed 20 s wait; '\r' keeps the countdown on a single console line.
    for i in range(20):
        time.sleep(1)
        print(f" 等待中... {i+1}/20秒", end='\r')

    print("\n\n📈 训练应该已完成,检查最终状态...")

    # Check the task status
    _print_final_status(tasks)

    print("\n" + "=" * 60)
    print("🎯 并发训练测试总结:")
    print("1. 检查API服务器控制台是否有清晰的多线程日志输出")
    print("2. 验证不同训练任务的日志是否能够区分")
    print("3. 确认并发训练时日志不会混乱或丢失")
    print("4. 线程安全的日志输出函数是否正常工作")
    print("=" * 60)
|
||
|
||
# Run the test only when the file is executed as a script, not on import.
if __name__ == "__main__":
    main()