ShopTRAINING/test/test_websocket_fix.py
2025-07-02 11:05:23 +08:00

165 lines
4.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试WebSocket修复效果
"""
import os
import sys
import time
import requests
import socketio
# 设置环境变量
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_websocket_connection():
"""测试WebSocket连接"""
print("=" * 60)
print("🧪 WebSocket连接修复测试")
print("=" * 60)
# API服务器地址
api_base = "http://localhost:5000"
# 测试HTTP连接
try:
response = requests.get(f"{api_base}/api/health", timeout=5)
if response.status_code == 200:
print("✅ HTTP API服务器连接正常")
else:
print(f"⚠️ HTTP API服务器响应异常: {response.status_code}")
return
except Exception as e:
print(f"❌ 无法连接到API服务器: {e}")
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
return
# 创建SocketIO客户端
sio = socketio.Client(logger=True, engineio_logger=False)
# 连接状态
connected = False
connection_error = None
@sio.event
def connect():
nonlocal connected
connected = True
print("✅ WebSocket连接成功")
@sio.event
def disconnect():
print("🔌 WebSocket连接断开")
@sio.event
def connection_established(data):
print(f"📡 收到连接确认: {data}")
@sio.event
def training_progress_detailed(data):
task_id = data.get('task_id', 'unknown')[:8]
message = data.get('message', 'No message')
print(f"📊 [训练进度] {task_id}: {message}")
@sio.event
def training_update(data):
task_id = data.get('task_id', 'unknown')[:8]
status = data.get('status', 'unknown')
progress = data.get('progress', 0)
print(f"🔄 [训练更新] {task_id}: {status} ({progress}%)")
@sio.on('*')
def catch_all(event, data):
print(f"📨 收到事件: {event} -> {data}")
# 尝试连接WebSocket
try:
print(f"\n🔗 尝试连接WebSocket: {api_base}")
sio.connect(f"{api_base}", namespaces=['/training'])
# 等待连接建立
timeout = 10
for i in range(timeout):
if connected:
break
time.sleep(1)
print(f"⏳ 等待连接... ({i+1}/{timeout})")
if connected:
print(f"🎉 WebSocket连接测试成功!")
# 测试保持连接几秒钟
print(f"\n⏱️ 保持连接5秒钟监听事件...")
time.sleep(5)
# 断开连接
sio.disconnect()
print(f"👋 主动断开WebSocket连接")
else:
print(f"❌ WebSocket连接超时")
except Exception as e:
connection_error = str(e)
print(f"❌ WebSocket连接失败: {e}")
# 测试HTTP轮询备用方案
try:
print(f"\n🔄 测试HTTP轮询方式...")
sio_polling = socketio.Client(logger=False, engineio_logger=False)
polling_connected = False
@sio_polling.event
def connect():
nonlocal polling_connected
polling_connected = True
print("✅ HTTP轮询连接成功")
# 使用轮询传输方式
sio_polling.connect(f"{api_base}", transports=['polling'], namespaces=['/training'])
# 等待连接
for i in range(5):
if polling_connected:
break
time.sleep(1)
if polling_connected:
print("🎉 HTTP轮询备用方案工作正常")
sio_polling.disconnect()
else:
print("⚠️ HTTP轮询也无法连接")
except Exception as e:
print(f"⚠️ HTTP轮询测试失败: {e}")
print("\n" + "=" * 60)
print("🎯 WebSocket测试结果:")
if connected:
print("✅ WebSocket协议: 正常")
else:
print("❌ WebSocket协议: 失败")
if connection_error:
print(f" 错误详情: {connection_error}")
if polling_connected:
print("✅ HTTP轮询备用: 正常")
else:
print("⚠️ HTTP轮询备用: 需要检查")
print("\n💡 修复建议:")
if not connected and not polling_connected:
print("1. 检查Flask-SocketIO版本兼容性")
print("2. 确认防火墙和端口设置")
print("3. 检查服务器日志输出")
elif connected:
print("✅ WebSocket修复成功可以正常使用实时功能")
print("=" * 60)
if __name__ == "__main__":
test_websocket_connection()