165 lines
4.7 KiB
Python
165 lines
4.7 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试WebSocket修复效果
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
import requests
|
|||
|
import socketio
|
|||
|
|
|||
|
# 设置环境变量
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
|
|||
|
def test_websocket_connection():
|
|||
|
"""测试WebSocket连接"""
|
|||
|
|
|||
|
print("=" * 60)
|
|||
|
print("🧪 WebSocket连接修复测试")
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
# API服务器地址
|
|||
|
api_base = "http://localhost:5000"
|
|||
|
|
|||
|
# 测试HTTP连接
|
|||
|
try:
|
|||
|
response = requests.get(f"{api_base}/api/health", timeout=5)
|
|||
|
if response.status_code == 200:
|
|||
|
print("✅ HTTP API服务器连接正常")
|
|||
|
else:
|
|||
|
print(f"⚠️ HTTP API服务器响应异常: {response.status_code}")
|
|||
|
return
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 无法连接到API服务器: {e}")
|
|||
|
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
|||
|
return
|
|||
|
|
|||
|
# 创建SocketIO客户端
|
|||
|
sio = socketio.Client(logger=True, engineio_logger=False)
|
|||
|
|
|||
|
# 连接状态
|
|||
|
connected = False
|
|||
|
connection_error = None
|
|||
|
|
|||
|
@sio.event
|
|||
|
def connect():
|
|||
|
nonlocal connected
|
|||
|
connected = True
|
|||
|
print("✅ WebSocket连接成功")
|
|||
|
|
|||
|
@sio.event
|
|||
|
def disconnect():
|
|||
|
print("🔌 WebSocket连接断开")
|
|||
|
|
|||
|
@sio.event
|
|||
|
def connection_established(data):
|
|||
|
print(f"📡 收到连接确认: {data}")
|
|||
|
|
|||
|
@sio.event
|
|||
|
def training_progress_detailed(data):
|
|||
|
task_id = data.get('task_id', 'unknown')[:8]
|
|||
|
message = data.get('message', 'No message')
|
|||
|
print(f"📊 [训练进度] {task_id}: {message}")
|
|||
|
|
|||
|
@sio.event
|
|||
|
def training_update(data):
|
|||
|
task_id = data.get('task_id', 'unknown')[:8]
|
|||
|
status = data.get('status', 'unknown')
|
|||
|
progress = data.get('progress', 0)
|
|||
|
print(f"🔄 [训练更新] {task_id}: {status} ({progress}%)")
|
|||
|
|
|||
|
@sio.on('*')
|
|||
|
def catch_all(event, data):
|
|||
|
print(f"📨 收到事件: {event} -> {data}")
|
|||
|
|
|||
|
# 尝试连接WebSocket
|
|||
|
try:
|
|||
|
print(f"\n🔗 尝试连接WebSocket: {api_base}")
|
|||
|
sio.connect(f"{api_base}", namespaces=['/training'])
|
|||
|
|
|||
|
# 等待连接建立
|
|||
|
timeout = 10
|
|||
|
for i in range(timeout):
|
|||
|
if connected:
|
|||
|
break
|
|||
|
time.sleep(1)
|
|||
|
print(f"⏳ 等待连接... ({i+1}/{timeout})")
|
|||
|
|
|||
|
if connected:
|
|||
|
print(f"🎉 WebSocket连接测试成功!")
|
|||
|
|
|||
|
# 测试保持连接几秒钟
|
|||
|
print(f"\n⏱️ 保持连接5秒钟,监听事件...")
|
|||
|
time.sleep(5)
|
|||
|
|
|||
|
# 断开连接
|
|||
|
sio.disconnect()
|
|||
|
print(f"👋 主动断开WebSocket连接")
|
|||
|
|
|||
|
else:
|
|||
|
print(f"❌ WebSocket连接超时")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
connection_error = str(e)
|
|||
|
print(f"❌ WebSocket连接失败: {e}")
|
|||
|
|
|||
|
# 测试HTTP轮询备用方案
|
|||
|
try:
|
|||
|
print(f"\n🔄 测试HTTP轮询方式...")
|
|||
|
sio_polling = socketio.Client(logger=False, engineio_logger=False)
|
|||
|
|
|||
|
polling_connected = False
|
|||
|
|
|||
|
@sio_polling.event
|
|||
|
def connect():
|
|||
|
nonlocal polling_connected
|
|||
|
polling_connected = True
|
|||
|
print("✅ HTTP轮询连接成功")
|
|||
|
|
|||
|
# 使用轮询传输方式
|
|||
|
sio_polling.connect(f"{api_base}", transports=['polling'], namespaces=['/training'])
|
|||
|
|
|||
|
# 等待连接
|
|||
|
for i in range(5):
|
|||
|
if polling_connected:
|
|||
|
break
|
|||
|
time.sleep(1)
|
|||
|
|
|||
|
if polling_connected:
|
|||
|
print("🎉 HTTP轮询备用方案工作正常")
|
|||
|
sio_polling.disconnect()
|
|||
|
else:
|
|||
|
print("⚠️ HTTP轮询也无法连接")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"⚠️ HTTP轮询测试失败: {e}")
|
|||
|
|
|||
|
print("\n" + "=" * 60)
|
|||
|
print("🎯 WebSocket测试结果:")
|
|||
|
|
|||
|
if connected:
|
|||
|
print("✅ WebSocket协议: 正常")
|
|||
|
else:
|
|||
|
print("❌ WebSocket协议: 失败")
|
|||
|
if connection_error:
|
|||
|
print(f" 错误详情: {connection_error}")
|
|||
|
|
|||
|
if polling_connected:
|
|||
|
print("✅ HTTP轮询备用: 正常")
|
|||
|
else:
|
|||
|
print("⚠️ HTTP轮询备用: 需要检查")
|
|||
|
|
|||
|
print("\n💡 修复建议:")
|
|||
|
if not connected and not polling_connected:
|
|||
|
print("1. 检查Flask-SocketIO版本兼容性")
|
|||
|
print("2. 确认防火墙和端口设置")
|
|||
|
print("3. 检查服务器日志输出")
|
|||
|
elif connected:
|
|||
|
print("✅ WebSocket修复成功,可以正常使用实时功能")
|
|||
|
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_websocket_connection()
|