ShopTRAINING/test/test_thread_output.py
2025-07-02 11:05:23 +08:00

102 lines
3.2 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试线程输出问题 - 验证子线程中的print是否能正确显示
"""
import os
import sys
import threading
import time
# 设置UTF-8编码
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_main_thread_output():
"""测试主线程输出"""
print("📍 [主线程] 开始测试", flush=True)
print("📍 [主线程] 这是主线程的输出", flush=True)
print("📍 [主线程] 中文测试: 药店销售预测系统", flush=True)
print("📍 [主线程] 表情符号: 🚀 📊 🤖", flush=True)
def test_child_thread_output():
"""测试子线程输出"""
print("🔗 [子线程] 子线程开始执行", flush=True)
# 模拟训练过程
for i in range(3):
print(f"🔗 [子线程] 模拟训练 Epoch {i+1}/3", flush=True)
print(f"🔗 [子线程] 中文输出: 正在训练第 {i+1}", flush=True)
time.sleep(1)
print("🔗 [子线程] 训练完成!", flush=True)
print("🔗 [子线程] 子线程执行结束", flush=True)
def test_thread_with_sys_flush():
"""测试带强制刷新的子线程输出"""
print("⚡ [强制刷新线程] 开始执行", flush=True)
sys.stdout.flush()
sys.stderr.flush()
for i in range(3):
print(f"⚡ [强制刷新线程] 训练进度 {i+1}/3", flush=True)
# 强制刷新
sys.stdout.flush()
sys.stderr.flush()
time.sleep(0.5)
print("⚡ [强制刷新线程] 完成!", flush=True)
sys.stdout.flush()
def main():
"""主测试函数"""
print("=" * 60)
print("🧪 线程输出测试开始")
print("=" * 60)
# 1. 测试主线程输出
print("\n1⃣ 主线程输出测试:")
test_main_thread_output()
# 2. 测试普通子线程输出
print("\n2⃣ 普通子线程输出测试:")
thread1 = threading.Thread(target=test_child_thread_output)
thread1.start()
thread1.join()
# 3. 测试强制刷新的子线程输出
print("\n3⃣ 强制刷新子线程输出测试:")
thread2 = threading.Thread(target=test_thread_with_sys_flush)
thread2.start()
thread2.join()
# 4. 模拟API训练任务的线程模式
print("\n4⃣ 模拟API训练任务:")
def simulate_api_training():
"""模拟API训练任务"""
task_id = "test-task-123"
print(f"🚀 [API模拟] 训练任务开始: {task_id}", flush=True)
# 模拟预测器调用
for epoch in range(2):
print(f"📊 [API模拟] Epoch {epoch+1}/2, 训练损失: 0.12{epoch}", flush=True)
sys.stdout.flush()
time.sleep(1)
print(f"✅ [API模拟] 训练完成: {task_id}", flush=True)
sys.stdout.flush()
api_thread = threading.Thread(target=simulate_api_training)
api_thread.start()
api_thread.join()
print("\n" + "=" * 60)
print("🎯 测试结论:")
print("1. 如果看到所有线程的输出,说明线程输出正常")
print("2. 如果只看到主线程输出,说明子线程输出被阻塞")
print("3. 这将帮助我们确定API训练日志问题的根源")
print("=" * 60)
if __name__ == "__main__":
main()