ShopTRAINING/test/test_websocket_fix.py

165 lines
4.7 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试WebSocket修复效果
"""
import os
import sys
import time
import requests
import socketio
# 设置环境变量
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_websocket_connection():
"""测试WebSocket连接"""
print("=" * 60)
print("🧪 WebSocket连接修复测试")
print("=" * 60)
# API服务器地址
api_base = "http://localhost:5000"
# 测试HTTP连接
try:
response = requests.get(f"{api_base}/api/health", timeout=5)
if response.status_code == 200:
print("✅ HTTP API服务器连接正常")
else:
print(f"⚠️ HTTP API服务器响应异常: {response.status_code}")
return
except Exception as e:
print(f"❌ 无法连接到API服务器: {e}")
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
return
# 创建SocketIO客户端
sio = socketio.Client(logger=True, engineio_logger=False)
# 连接状态
connected = False
connection_error = None
@sio.event
def connect():
nonlocal connected
connected = True
print("✅ WebSocket连接成功")
@sio.event
def disconnect():
print("🔌 WebSocket连接断开")
@sio.event
def connection_established(data):
print(f"📡 收到连接确认: {data}")
@sio.event
def training_progress_detailed(data):
task_id = data.get('task_id', 'unknown')[:8]
message = data.get('message', 'No message')
print(f"📊 [训练进度] {task_id}: {message}")
@sio.event
def training_update(data):
task_id = data.get('task_id', 'unknown')[:8]
status = data.get('status', 'unknown')
progress = data.get('progress', 0)
print(f"🔄 [训练更新] {task_id}: {status} ({progress}%)")
@sio.on('*')
def catch_all(event, data):
print(f"📨 收到事件: {event} -> {data}")
# 尝试连接WebSocket
try:
print(f"\n🔗 尝试连接WebSocket: {api_base}")
sio.connect(f"{api_base}", namespaces=['/training'])
# 等待连接建立
timeout = 10
for i in range(timeout):
if connected:
break
time.sleep(1)
print(f"⏳ 等待连接... ({i+1}/{timeout})")
if connected:
print(f"🎉 WebSocket连接测试成功!")
# 测试保持连接几秒钟
print(f"\n⏱️ 保持连接5秒钟监听事件...")
time.sleep(5)
# 断开连接
sio.disconnect()
print(f"👋 主动断开WebSocket连接")
else:
print(f"❌ WebSocket连接超时")
except Exception as e:
connection_error = str(e)
print(f"❌ WebSocket连接失败: {e}")
# 测试HTTP轮询备用方案
try:
print(f"\n🔄 测试HTTP轮询方式...")
sio_polling = socketio.Client(logger=False, engineio_logger=False)
polling_connected = False
@sio_polling.event
def connect():
nonlocal polling_connected
polling_connected = True
print("✅ HTTP轮询连接成功")
# 使用轮询传输方式
sio_polling.connect(f"{api_base}", transports=['polling'], namespaces=['/training'])
# 等待连接
for i in range(5):
if polling_connected:
break
time.sleep(1)
if polling_connected:
print("🎉 HTTP轮询备用方案工作正常")
sio_polling.disconnect()
else:
print("⚠️ HTTP轮询也无法连接")
except Exception as e:
print(f"⚠️ HTTP轮询测试失败: {e}")
print("\n" + "=" * 60)
print("🎯 WebSocket测试结果:")
if connected:
print("✅ WebSocket协议: 正常")
else:
print("❌ WebSocket协议: 失败")
if connection_error:
print(f" 错误详情: {connection_error}")
if polling_connected:
print("✅ HTTP轮询备用: 正常")
else:
print("⚠️ HTTP轮询备用: 需要检查")
print("\n💡 修复建议:")
if not connected and not polling_connected:
print("1. 检查Flask-SocketIO版本兼容性")
print("2. 确认防火墙和端口设置")
print("3. 检查服务器日志输出")
elif connected:
print("✅ WebSocket修复成功可以正常使用实时功能")
print("=" * 60)
if __name__ == "__main__":
test_websocket_connection()