165 lines
4.7 KiB
Python
165 lines
4.7 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试WebSocket修复效果
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
import requests
|
||
import socketio
|
||
|
||
# 设置环境变量
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
|
||
def test_websocket_connection():
|
||
"""测试WebSocket连接"""
|
||
|
||
print("=" * 60)
|
||
print("🧪 WebSocket连接修复测试")
|
||
print("=" * 60)
|
||
|
||
# API服务器地址
|
||
api_base = "http://localhost:5000"
|
||
|
||
# 测试HTTP连接
|
||
try:
|
||
response = requests.get(f"{api_base}/api/health", timeout=5)
|
||
if response.status_code == 200:
|
||
print("✅ HTTP API服务器连接正常")
|
||
else:
|
||
print(f"⚠️ HTTP API服务器响应异常: {response.status_code}")
|
||
return
|
||
except Exception as e:
|
||
print(f"❌ 无法连接到API服务器: {e}")
|
||
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
||
return
|
||
|
||
# 创建SocketIO客户端
|
||
sio = socketio.Client(logger=True, engineio_logger=False)
|
||
|
||
# 连接状态
|
||
connected = False
|
||
connection_error = None
|
||
|
||
@sio.event
|
||
def connect():
|
||
nonlocal connected
|
||
connected = True
|
||
print("✅ WebSocket连接成功")
|
||
|
||
@sio.event
|
||
def disconnect():
|
||
print("🔌 WebSocket连接断开")
|
||
|
||
@sio.event
|
||
def connection_established(data):
|
||
print(f"📡 收到连接确认: {data}")
|
||
|
||
@sio.event
|
||
def training_progress_detailed(data):
|
||
task_id = data.get('task_id', 'unknown')[:8]
|
||
message = data.get('message', 'No message')
|
||
print(f"📊 [训练进度] {task_id}: {message}")
|
||
|
||
@sio.event
|
||
def training_update(data):
|
||
task_id = data.get('task_id', 'unknown')[:8]
|
||
status = data.get('status', 'unknown')
|
||
progress = data.get('progress', 0)
|
||
print(f"🔄 [训练更新] {task_id}: {status} ({progress}%)")
|
||
|
||
@sio.on('*')
|
||
def catch_all(event, data):
|
||
print(f"📨 收到事件: {event} -> {data}")
|
||
|
||
# 尝试连接WebSocket
|
||
try:
|
||
print(f"\n🔗 尝试连接WebSocket: {api_base}")
|
||
sio.connect(f"{api_base}", namespaces=['/training'])
|
||
|
||
# 等待连接建立
|
||
timeout = 10
|
||
for i in range(timeout):
|
||
if connected:
|
||
break
|
||
time.sleep(1)
|
||
print(f"⏳ 等待连接... ({i+1}/{timeout})")
|
||
|
||
if connected:
|
||
print(f"🎉 WebSocket连接测试成功!")
|
||
|
||
# 测试保持连接几秒钟
|
||
print(f"\n⏱️ 保持连接5秒钟,监听事件...")
|
||
time.sleep(5)
|
||
|
||
# 断开连接
|
||
sio.disconnect()
|
||
print(f"👋 主动断开WebSocket连接")
|
||
|
||
else:
|
||
print(f"❌ WebSocket连接超时")
|
||
|
||
except Exception as e:
|
||
connection_error = str(e)
|
||
print(f"❌ WebSocket连接失败: {e}")
|
||
|
||
# 测试HTTP轮询备用方案
|
||
try:
|
||
print(f"\n🔄 测试HTTP轮询方式...")
|
||
sio_polling = socketio.Client(logger=False, engineio_logger=False)
|
||
|
||
polling_connected = False
|
||
|
||
@sio_polling.event
|
||
def connect():
|
||
nonlocal polling_connected
|
||
polling_connected = True
|
||
print("✅ HTTP轮询连接成功")
|
||
|
||
# 使用轮询传输方式
|
||
sio_polling.connect(f"{api_base}", transports=['polling'], namespaces=['/training'])
|
||
|
||
# 等待连接
|
||
for i in range(5):
|
||
if polling_connected:
|
||
break
|
||
time.sleep(1)
|
||
|
||
if polling_connected:
|
||
print("🎉 HTTP轮询备用方案工作正常")
|
||
sio_polling.disconnect()
|
||
else:
|
||
print("⚠️ HTTP轮询也无法连接")
|
||
|
||
except Exception as e:
|
||
print(f"⚠️ HTTP轮询测试失败: {e}")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("🎯 WebSocket测试结果:")
|
||
|
||
if connected:
|
||
print("✅ WebSocket协议: 正常")
|
||
else:
|
||
print("❌ WebSocket协议: 失败")
|
||
if connection_error:
|
||
print(f" 错误详情: {connection_error}")
|
||
|
||
if polling_connected:
|
||
print("✅ HTTP轮询备用: 正常")
|
||
else:
|
||
print("⚠️ HTTP轮询备用: 需要检查")
|
||
|
||
print("\n💡 修复建议:")
|
||
if not connected and not polling_connected:
|
||
print("1. 检查Flask-SocketIO版本兼容性")
|
||
print("2. 确认防火墙和端口设置")
|
||
print("3. 检查服务器日志输出")
|
||
elif connected:
|
||
print("✅ WebSocket修复成功,可以正常使用实时功能")
|
||
|
||
print("=" * 60)
|
||
|
||
if __name__ == "__main__":
|
||
test_websocket_connection() |