ShopTRAINING/test/test_concurrent_training.py
2025-07-02 11:05:23 +08:00

157 lines
5.5 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
并发训练测试 - 测试多个模型同时训练的日志输出
"""
import os
import sys
import urllib.request
import urllib.parse
import json
import time
import threading
# Request UTF-8 text I/O through the interpreter's environment switches.
os.environ.update({
    'PYTHONIOENCODING': 'utf-8',
    'PYTHONLEGACYWINDOWSSTDIO': '0',
})

# Additional console configuration that only applies on Windows.
if os.name == 'nt':
    try:
        # Switch the console code page to 65001, discarding chcp's own output.
        os.system('chcp 65001 >nul 2>&1')
        # Streams without reconfigure() are left untouched.
        if hasattr(sys.stdout, 'reconfigure'):
            sys.stdout.reconfigure(
                encoding='utf-8',
                errors='replace',
            )
            sys.stderr.reconfigure(
                encoding='utf-8',
                errors='replace',
            )
    except Exception as exc:
        # Best effort only: report the problem and keep running.
        print(f"Warning: Failed to set UTF-8 encoding: {exc}")
def send_training_request(model_type, product_id="P001", epochs=1):
    """Submit a training request to the local API server.

    The request body is a JSON object with the keys ``product_id``,
    ``model_type``, ``epochs`` and ``training_mode`` (always ``"product"``),
    sent to ``http://127.0.0.1:5000/api/training`` with a 30-second timeout.

    Args:
        model_type: Value sent as ``model_type``.
        product_id: Value sent as ``product_id``.
        epochs: Value sent as ``epochs``.

    Returns:
        ``(status_code, decoded_json_body)`` when the exchange succeeds, or
        ``(None, error_text)`` when any exception is raised while encoding
        the payload, sending the request, or decoding the reply.
    """
    payload = dict(
        product_id=product_id,
        model_type=model_type,
        epochs=epochs,
        training_mode="product",
    )
    try:
        # Encoding happens inside the try so a non-serialisable argument is
        # reported through the return value as well.
        request = urllib.request.Request(
            "http://127.0.0.1:5000/api/training",
            data=json.dumps(payload).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(request, timeout=30) as reply:
            status_code = reply.getcode()
            body_text = reply.read().decode('utf-8')
            return status_code, json.loads(body_text)
    except Exception as exc:
        return None, str(exc)
def test_single_training(model_type, test_name):
    """Start one training task and print whether the start succeeded.

    Args:
        model_type: Model type forwarded to ``send_training_request``.
        test_name: Label placed at the front of every printed line.

    Returns:
        The ``task_id`` taken from the response when the request returned
        status 200 and the response carried a task id; otherwise ``None``.
    """
    print(f"🚀 {test_name}: 启动 {model_type} 训练...")
    status, response = send_training_request(model_type)

    # A 200 reply is only usable when it is a JSON object carrying a task id.
    # Slicing a missing id would raise TypeError, so such a reply is reported
    # as a failed start instead.
    task_id = response.get('task_id') if isinstance(response, dict) else None
    if status != 200 or not task_id:
        print(f"{test_name}: {model_type} 训练启动失败: {response}")
        return None

    # str() keeps the short-id display working for non-string ids as well.
    print(f"{test_name}: {model_type} 训练启动成功, 任务ID: {str(task_id)[:8]}")
    return task_id
def _api_is_reachable():
    """Return True when GET /api/products on the local server answers 200."""
    try:
        with urllib.request.urlopen("http://127.0.0.1:5000/api/products", timeout=3) as response:
            return response.getcode() == 200
    except Exception:
        # Any request failure means the server is not usable. Catching
        # Exception rather than everything lets Ctrl-C and SystemExit
        # propagate instead of being reported as "server not running".
        return False


def _print_final_status(tasks):
    """Fetch the task list from the API and print the status of *tasks*.

    Args:
        tasks: List of ``(model, task_id)`` pairs started by this script.
    """
    # Built once instead of once per row returned by the API.
    wanted_ids = [task_id for _, task_id in tasks]
    try:
        with urllib.request.urlopen("http://127.0.0.1:5000/api/training", timeout=5) as response:
            result = json.loads(response.read().decode('utf-8'))
        recent_tasks = [t for t in result.get('data', []) if t.get('task_id') in wanted_ids]
        print("📊 并发任务完成状态:")
        for task in recent_tasks:
            status = task.get('status', 'unknown')
            model_type = task.get('model_type', 'unknown')
            task_id = str(task.get('task_id', 'unknown'))[:8]
            print(f" {model_type}: {task_id} - {status}")
    except Exception as e:
        print(f"❌ 检查状态失败: {e}")


def main():
    """Run the manual concurrent-training log check against the local API.

    Steps, in order:
      1. Verify the API server answers; print a hint and return if not.
      2. Start a single "transformer" task and, if it started, wait 5 seconds.
      3. Start "mlstm", "tcn" and "kan" tasks 0.5 seconds apart.
      4. Wait 20 seconds, then print the status the API reports for the
         tasks started in step 3.
      5. Print a checklist of what to verify in the server console.

    The script only prints guidance; the log output under test appears in
    the API server's own console.
    """
    print("🔧 并发训练日志测试")
    print("=" * 60)
    print("🎯 目标: 测试多个模型同时训练时的日志输出")
    print("📋 期望: 不同训练任务的日志能清晰区分")
    print("=" * 60)

    # Check the API connection before starting anything.
    if not _api_is_reachable():
        print("❌ API服务器未运行请先启动: PYTHONIOENCODING=utf-8 uv run server/api.py")
        return
    print("✅ API服务器连接正常\n")

    # Test 1: a single training task.
    print("🧪 测试1: 单个训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台应该看到:")
    print(" [时间戳][线程ID][任务ID][标签] 消息内容")
    print(" 例如: [14:30:45][线程12345][826b7ef4][ENTRY] 🔥 训练任务线程启动")
    task1 = test_single_training("transformer", "单任务测试")
    if task1:
        print("⏳ 等待5秒观察单任务日志...")
        time.sleep(5)

    # Test 2: concurrent training tasks.
    print("\n🧪 测试2: 并发训练任务")
    print("-" * 40)
    print("👀 观察API服务器控制台应该看到:")
    print(" 多个不同线程ID和任务ID的交替输出")
    print(" 通过[线程ID]和[任务ID]可以区分不同的训练任务")

    # Start three different training tasks back to back.
    models = ["mlstm", "tcn", "kan"]
    tasks = []
    print("🚀 同时启动3个训练任务...")
    for index, model in enumerate(models, start=1):
        task_id = test_single_training(model, f"并发任务{index}")
        if task_id:
            tasks.append((model, task_id))
        time.sleep(0.5)  # Stagger the start times slightly.

    print(f"\n📊 启动了 {len(tasks)} 个并发训练任务:")
    for model, task_id in tasks:
        print(f" {model}: {str(task_id)[:8]}")

    print("\n⏳ 等待20秒观察并发训练日志...")
    print("💡 重点观察:")
    print(" 1. 不同任务的日志是否能通过线程ID和任务ID区分")
    print(" 2. 日志输出是否完整,没有丢失")
    print(" 3. 多个训练器是否能同时输出进度信息")

    # Give the tasks time to run; '\r' rewrites the same console line.
    for elapsed in range(1, 21):
        time.sleep(1)
        print(f" 等待中... {elapsed}/20秒", end='\r')
    print("\n\n📈 训练应该已完成,检查最终状态...")

    # Check the state the API reports for the concurrent tasks.
    _print_final_status(tasks)

    print("\n" + "=" * 60)
    print("🎯 并发训练测试总结:")
    print("1. 检查API服务器控制台是否有清晰的多线程日志输出")
    print("2. 验证不同训练任务的日志是否能够区分")
    print("3. 确认并发训练时日志不会混乱或丢失")
    print("4. 线程安全的日志输出函数是否正常工作")
    print("=" * 60)
# Run the check only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()