ShopTRAINING/test/test_concurrent_training.py

157 lines
5.5 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Concurrent training test - checks the log output when several models are trained at the same time
"""
import os
import sys
import urllib.request
import urllib.parse
import json
import time
import threading
# Request UTF-8 for Python standard I/O. These variables are set in the
# process environment, so child processes started from here inherit them.
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'

# Extra configuration on Windows: switch the console code page to UTF-8 and
# re-encode the already-open standard streams.
if os.name == 'nt':
    try:
        os.system('chcp 65001 >nul 2>&1')
        # Check each stream on its own: one of them may have been replaced by
        # an object without ``reconfigure`` while the other still supports it.
        for _stream in (sys.stdout, sys.stderr):
            if hasattr(_stream, 'reconfigure'):
                _stream.reconfigure(encoding='utf-8', errors='replace')
    except Exception as e:
        # Best effort only: the script is still usable without UTF-8 output.
        print(f"Warning: Failed to set UTF-8 encoding: {e}")
def send_training_request(model_type, product_id="P001", epochs=1,
                          base_url="http://127.0.0.1:5000", timeout=30):
    """Send a training request to the API and return ``(status, payload)``.

    Args:
        model_type: Sent as ``model_type`` in the JSON request body.
        product_id: Sent as ``product_id`` in the JSON request body.
        epochs: Sent as ``epochs`` in the JSON request body.
        base_url: Root URL of the API server; ``/api/training`` is appended.
        timeout: Timeout in seconds passed to ``urllib.request.urlopen``.

    Returns:
        A 2-tuple:

        * ``(status_code, decoded_json)`` when the server answers normally.
        * ``(status_code, body)`` when the server answers with an HTTP error;
          ``body`` is the decoded JSON error body if it parses, otherwise the
          raw text (or the error message when the body is empty/unreadable).
        * ``(None, error_message)`` for every other failure (invalid URL,
          connection problems, malformed JSON in a success response, ...).

    The function never raises; failures are reported through the return value.
    """
    import urllib.error  # local import: keeps this block self-contained

    training_data = {
        "product_id": product_id,
        "model_type": model_type,
        "epochs": epochs,
        "training_mode": "product",
    }
    try:
        json_data = json.dumps(training_data).encode('utf-8')
        req = urllib.request.Request(
            f"{base_url.rstrip('/')}/api/training",
            data=json_data,
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(req, timeout=timeout) as response:
            return response.getcode(), json.loads(response.read().decode('utf-8'))
    except urllib.error.HTTPError as e:
        # Keep the real status code and the server's own error description
        # instead of collapsing them into "HTTP Error 4xx: ...".
        return e.code, _decode_error_body(e)
    except Exception as e:
        # Deliberate catch-all: callers rely on the (None, message) contract.
        return None, str(e)


def _decode_error_body(err):
    """Return the body of an ``HTTPError`` as parsed JSON, else as text."""
    try:
        with err:
            text = err.read().decode('utf-8', errors='replace')
    except Exception:
        # Body unavailable or already closed: fall back to the error message.
        return str(err)
    try:
        return json.loads(text)
    except ValueError:
        return text or str(err)
def test_single_training(model_type, test_name):
    """Start one training task and print whether it was accepted.

    Args:
        model_type: Model identifier forwarded to ``send_training_request``.
        test_name: Label placed in front of every line printed here.

    Returns:
        The ``task_id`` value from the response when the request returned
        HTTP 200 and the response contained a task ID; otherwise ``None``.
    """
    print(f"🚀 {test_name}: 启动 {model_type} 训练...")
    status, response = send_training_request(model_type)
    if status != 200:
        print(f"{test_name}: {model_type} 训练启动失败: {response}")
        return None

    # A 200 response without a usable task ID used to crash on ``None[:8]``;
    # report it as a failed start instead.
    task_id = response.get('task_id') if isinstance(response, dict) else None
    if not task_id:
        print(f"{test_name}: {model_type} 训练启动失败: 响应中缺少 task_id: {response}")
        return None

    # str() keeps the short-ID display working for non-string IDs as well.
    print(f"{test_name}: {model_type} 训练启动成功, 任务ID: {str(task_id)[:8]}")
    return task_id


# The ``test_`` prefix would make pytest collect this helper and then fail on
# the unknown ``model_type`` / ``test_name`` fixtures; opt out of collection.
test_single_training.__test__ = False
def main():
    """Run the manual concurrent-training log check against a local API server.

    Steps, all reported on stdout:

    1. Probe ``GET /api/products``; print a start-up hint and return if the
       server does not answer with HTTP 200.
    2. Start a single ``transformer`` training task and wait 5 seconds.
    3. Start ``mlstm``, ``tcn`` and ``kan`` tasks 0.5 seconds apart and wait
       20 seconds.
    4. Fetch ``GET /api/training`` and print the status of the tasks started
       in step 3.

    The function returns ``None``. It does not assert anything itself: the
    operator is expected to inspect the API server console, as the printed
    instructions explain.
    """
    api_root = "http://127.0.0.1:5000"

    print("🔧 并发训练日志测试")
    print("=" * 60)
    print("🎯 目标: 测试多个模型同时训练时的日志输出")
    print("📋 期望: 不同训练任务的日志能清晰区分")
    print("=" * 60)

    # Check the API connection before starting anything.
    if not _api_is_reachable(f"{api_root}/api/products"):
        print("❌ API服务器未运行请先启动: PYTHONIOENCODING=utf-8 uv run server/api.py")
        return
    print("✅ API服务器连接正常\n")

    # Test 1: a single training task.
    print("🧪 测试1: 单个训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台应该看到:")
    print(" [时间戳][线程ID][任务ID][标签] 消息内容")
    print(" 例如: [14:30:45][线程12345][826b7ef4][ENTRY] 🔥 训练任务线程启动")
    task1 = test_single_training("transformer", "单任务测试")
    if task1:
        print("⏳ 等待5秒观察单任务日志...")
        time.sleep(5)

    # Test 2: several training tasks at once.
    print("\n🧪 测试2: 并发训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台应该看到:")
    print(" 多个不同线程ID和任务ID的交替输出")
    print(" 通过[线程ID]和[任务ID]可以区分不同的训练任务")

    # Start three different training tasks back to back.
    models = ["mlstm", "tcn", "kan"]
    tasks = []
    print("🚀 同时启动3个训练任务...")
    for index, model in enumerate(models, start=1):
        task_id = test_single_training(model, f"并发任务{index}")
        if task_id:
            tasks.append((model, task_id))
        time.sleep(0.5)  # stagger the start times slightly

    print(f"\n📊 启动了 {len(tasks)} 个并发训练任务:")
    for model, task_id in tasks:
        # str() so a non-string task ID cannot break the short-ID display.
        print(f" {model}: {str(task_id)[:8]}")

    print("\n⏳ 等待20秒观察并发训练日志...")
    print("💡 重点观察:")
    print(" 1. 不同任务的日志是否能通过线程ID和任务ID区分")
    print(" 2. 日志输出是否完整,没有丢失")
    print(" 3. 多个训练器是否能同时输出进度信息")

    # Give the training tasks time to finish, with a one-line countdown.
    for second in range(1, 21):
        time.sleep(1)
        print(f" 等待中... {second}/20秒", end='\r')
    print("\n\n📈 训练应该已完成,检查最终状态...")

    # Report the final state of the concurrently started tasks.
    _print_task_statuses(f"{api_root}/api/training", tasks)

    print("\n" + "=" * 60)
    print("🎯 并发训练测试总结:")
    print("1. 检查API服务器控制台是否有清晰的多线程日志输出")
    print("2. 验证不同训练任务的日志是否能够区分")
    print("3. 确认并发训练时日志不会混乱或丢失")
    print("4. 线程安全的日志输出函数是否正常工作")
    print("=" * 60)


def _api_is_reachable(url, timeout=3):
    """Return True if a GET on *url* answers with HTTP 200, False on any error."""
    try:
        with urllib.request.urlopen(url, timeout=timeout) as response:
            return response.getcode() == 200
    except Exception:
        # Previously a bare ``except:``; catching Exception instead lets
        # Ctrl+C (KeyboardInterrupt) and SystemExit propagate.
        return False


def _print_task_statuses(url, tasks):
    """Fetch the task list from *url* and print the status of the launched tasks.

    *tasks* is a list of ``(model, task_id)`` pairs. Any failure while fetching
    or decoding the list is printed rather than raised.
    """
    # Build the list of launched IDs once instead of once per returned task.
    launched_ids = [task_id for _, task_id in tasks]
    try:
        with urllib.request.urlopen(url, timeout=5) as response:
            result = json.loads(response.read().decode('utf-8'))
        recent_tasks = [
            entry for entry in result.get('data', [])
            if entry.get('task_id') in launched_ids
        ]
        print("📊 并发任务完成状态:")
        for entry in recent_tasks:
            status = entry.get('status', 'unknown')
            model_type = entry.get('model_type', 'unknown')
            task_id = str(entry.get('task_id', 'unknown'))[:8]
            print(f" {model_type}: {task_id} - {status}")
    except Exception as e:
        print(f"❌ 检查状态失败: {e}")
# Run the manual test only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()