175 lines
5.9 KiB
Python
175 lines
5.9 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试CORS跨域请求修复效果
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import requests
|
||
import json
|
||
|
||
# 设置环境变量
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
|
||
def test_cors_requests():
|
||
"""测试CORS跨域请求"""
|
||
|
||
print("=" * 60)
|
||
print("🧪 CORS跨域请求修复测试")
|
||
print("=" * 60)
|
||
|
||
# API服务器地址
|
||
api_base = "http://127.0.0.1:5000"
|
||
|
||
# 模拟不同来源的请求
|
||
test_origins = [
|
||
"http://localhost:5173", # Vue开发服务器
|
||
"http://127.0.0.1:5173", # 数字IP版本
|
||
"http://localhost:3000", # 备用端口
|
||
None # 无Origin头(直接访问)
|
||
]
|
||
|
||
# 测试端点列表
|
||
test_endpoints = [
|
||
("/api/health", "GET"),
|
||
("/api/training", "GET"),
|
||
("/api/stores", "GET"),
|
||
("/api/products", "GET")
|
||
]
|
||
|
||
print(f"🎯 测试目标: {api_base}")
|
||
print(f"🔍 测试端点数量: {len(test_endpoints)}")
|
||
print(f"🌐 测试来源数量: {len(test_origins)}")
|
||
print()
|
||
|
||
# 测试结果统计
|
||
success_count = 0
|
||
total_tests = 0
|
||
|
||
for origin in test_origins:
|
||
origin_name = origin if origin else "Direct Access"
|
||
print(f"📡 测试来源: {origin_name}")
|
||
print("-" * 40)
|
||
|
||
# 设置请求头
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Accept': 'application/json'
|
||
}
|
||
if origin:
|
||
headers['Origin'] = origin
|
||
|
||
for endpoint, method in test_endpoints:
|
||
total_tests += 1
|
||
url = f"{api_base}{endpoint}"
|
||
|
||
try:
|
||
# 首先发送OPTIONS预检请求
|
||
if origin:
|
||
options_response = requests.options(
|
||
url,
|
||
headers=headers,
|
||
timeout=5
|
||
)
|
||
|
||
# 检查预检请求响应
|
||
cors_headers = {
|
||
'Access-Control-Allow-Origin': options_response.headers.get('Access-Control-Allow-Origin'),
|
||
'Access-Control-Allow-Methods': options_response.headers.get('Access-Control-Allow-Methods'),
|
||
'Access-Control-Allow-Headers': options_response.headers.get('Access-Control-Allow-Headers')
|
||
}
|
||
|
||
if options_response.status_code != 200:
|
||
print(f" ❌ {method} {endpoint}: 预检失败 ({options_response.status_code})")
|
||
continue
|
||
|
||
# 发送实际请求
|
||
if method == "GET":
|
||
response = requests.get(url, headers=headers, timeout=5)
|
||
elif method == "POST":
|
||
response = requests.post(url, headers=headers, json={}, timeout=5)
|
||
else:
|
||
continue
|
||
|
||
# 检查响应
|
||
if response.status_code == 200:
|
||
# 检查CORS头
|
||
allow_origin = response.headers.get('Access-Control-Allow-Origin')
|
||
if allow_origin and (allow_origin == '*' or allow_origin == origin):
|
||
print(f" ✅ {method} {endpoint}: 成功 (CORS: {allow_origin})")
|
||
success_count += 1
|
||
else:
|
||
print(f" ⚠️ {method} {endpoint}: 成功但CORS头缺失")
|
||
success_count += 1
|
||
else:
|
||
print(f" ❌ {method} {endpoint}: HTTP {response.status_code}")
|
||
|
||
except requests.exceptions.ConnectionError:
|
||
print(f" ❌ {method} {endpoint}: 连接失败 (服务器未启动?)")
|
||
except requests.exceptions.Timeout:
|
||
print(f" ❌ {method} {endpoint}: 请求超时")
|
||
except Exception as e:
|
||
print(f" ❌ {method} {endpoint}: 错误 - {e}")
|
||
|
||
print()
|
||
|
||
# 专门测试训练端点的POST请求
|
||
print("🚀 专项测试: 训练端点POST请求")
|
||
print("-" * 40)
|
||
|
||
training_data = {
|
||
"product_id": "P001",
|
||
"model_type": "transformer",
|
||
"epochs": 1,
|
||
"training_mode": "product"
|
||
}
|
||
|
||
headers = {
|
||
'Content-Type': 'application/json',
|
||
'Origin': 'http://localhost:5173'
|
||
}
|
||
|
||
try:
|
||
# 测试start_training端点
|
||
response = requests.post(
|
||
f"{api_base}/api/start_training",
|
||
headers=headers,
|
||
json=training_data,
|
||
timeout=10
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
print(f" ✅ POST /api/start_training: 成功")
|
||
print(f" 任务ID: {result.get('task_id', 'N/A')}")
|
||
else:
|
||
print(f" ❌ POST /api/start_training: HTTP {response.status_code}")
|
||
print(f" 错误: {response.text}")
|
||
|
||
except Exception as e:
|
||
print(f" ❌ POST /api/start_training: 错误 - {e}")
|
||
|
||
# 测试结果总结
|
||
print("\n" + "=" * 60)
|
||
print("🎯 CORS测试结果总结:")
|
||
print(f"✅ 成功: {success_count}/{total_tests} ({success_count/total_tests*100:.1f}%)")
|
||
|
||
if success_count == total_tests:
|
||
print("🎉 CORS配置完全正常!")
|
||
elif success_count > total_tests * 0.8:
|
||
print("⚠️ CORS配置基本正常,少数端点需要检查")
|
||
else:
|
||
print("❌ CORS配置存在问题,需要进一步调试")
|
||
|
||
print("\n💡 故障排除建议:")
|
||
if success_count < total_tests:
|
||
print("1. 确认API服务器已启动: ./启动API服务器-修复版.bat")
|
||
print("2. 检查防火墙设置")
|
||
print("3. 验证端口5000是否被占用")
|
||
print("4. 查看服务器控制台日志")
|
||
|
||
print("=" * 60)
|
||
|
||
if __name__ == "__main__":
|
||
test_cors_requests() |