175 lines
5.9 KiB
Python
175 lines
5.9 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试CORS跨域请求修复效果
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import requests
|
|||
|
import json
|
|||
|
|
|||
|
# 设置环境变量
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
|
|||
|
def test_cors_requests():
|
|||
|
"""测试CORS跨域请求"""
|
|||
|
|
|||
|
print("=" * 60)
|
|||
|
print("🧪 CORS跨域请求修复测试")
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
# API服务器地址
|
|||
|
api_base = "http://127.0.0.1:5000"
|
|||
|
|
|||
|
# 模拟不同来源的请求
|
|||
|
test_origins = [
|
|||
|
"http://localhost:5173", # Vue开发服务器
|
|||
|
"http://127.0.0.1:5173", # 数字IP版本
|
|||
|
"http://localhost:3000", # 备用端口
|
|||
|
None # 无Origin头(直接访问)
|
|||
|
]
|
|||
|
|
|||
|
# 测试端点列表
|
|||
|
test_endpoints = [
|
|||
|
("/api/health", "GET"),
|
|||
|
("/api/training", "GET"),
|
|||
|
("/api/stores", "GET"),
|
|||
|
("/api/products", "GET")
|
|||
|
]
|
|||
|
|
|||
|
print(f"🎯 测试目标: {api_base}")
|
|||
|
print(f"🔍 测试端点数量: {len(test_endpoints)}")
|
|||
|
print(f"🌐 测试来源数量: {len(test_origins)}")
|
|||
|
print()
|
|||
|
|
|||
|
# 测试结果统计
|
|||
|
success_count = 0
|
|||
|
total_tests = 0
|
|||
|
|
|||
|
for origin in test_origins:
|
|||
|
origin_name = origin if origin else "Direct Access"
|
|||
|
print(f"📡 测试来源: {origin_name}")
|
|||
|
print("-" * 40)
|
|||
|
|
|||
|
# 设置请求头
|
|||
|
headers = {
|
|||
|
'Content-Type': 'application/json',
|
|||
|
'Accept': 'application/json'
|
|||
|
}
|
|||
|
if origin:
|
|||
|
headers['Origin'] = origin
|
|||
|
|
|||
|
for endpoint, method in test_endpoints:
|
|||
|
total_tests += 1
|
|||
|
url = f"{api_base}{endpoint}"
|
|||
|
|
|||
|
try:
|
|||
|
# 首先发送OPTIONS预检请求
|
|||
|
if origin:
|
|||
|
options_response = requests.options(
|
|||
|
url,
|
|||
|
headers=headers,
|
|||
|
timeout=5
|
|||
|
)
|
|||
|
|
|||
|
# 检查预检请求响应
|
|||
|
cors_headers = {
|
|||
|
'Access-Control-Allow-Origin': options_response.headers.get('Access-Control-Allow-Origin'),
|
|||
|
'Access-Control-Allow-Methods': options_response.headers.get('Access-Control-Allow-Methods'),
|
|||
|
'Access-Control-Allow-Headers': options_response.headers.get('Access-Control-Allow-Headers')
|
|||
|
}
|
|||
|
|
|||
|
if options_response.status_code != 200:
|
|||
|
print(f" ❌ {method} {endpoint}: 预检失败 ({options_response.status_code})")
|
|||
|
continue
|
|||
|
|
|||
|
# 发送实际请求
|
|||
|
if method == "GET":
|
|||
|
response = requests.get(url, headers=headers, timeout=5)
|
|||
|
elif method == "POST":
|
|||
|
response = requests.post(url, headers=headers, json={}, timeout=5)
|
|||
|
else:
|
|||
|
continue
|
|||
|
|
|||
|
# 检查响应
|
|||
|
if response.status_code == 200:
|
|||
|
# 检查CORS头
|
|||
|
allow_origin = response.headers.get('Access-Control-Allow-Origin')
|
|||
|
if allow_origin and (allow_origin == '*' or allow_origin == origin):
|
|||
|
print(f" ✅ {method} {endpoint}: 成功 (CORS: {allow_origin})")
|
|||
|
success_count += 1
|
|||
|
else:
|
|||
|
print(f" ⚠️ {method} {endpoint}: 成功但CORS头缺失")
|
|||
|
success_count += 1
|
|||
|
else:
|
|||
|
print(f" ❌ {method} {endpoint}: HTTP {response.status_code}")
|
|||
|
|
|||
|
except requests.exceptions.ConnectionError:
|
|||
|
print(f" ❌ {method} {endpoint}: 连接失败 (服务器未启动?)")
|
|||
|
except requests.exceptions.Timeout:
|
|||
|
print(f" ❌ {method} {endpoint}: 请求超时")
|
|||
|
except Exception as e:
|
|||
|
print(f" ❌ {method} {endpoint}: 错误 - {e}")
|
|||
|
|
|||
|
print()
|
|||
|
|
|||
|
# 专门测试训练端点的POST请求
|
|||
|
print("🚀 专项测试: 训练端点POST请求")
|
|||
|
print("-" * 40)
|
|||
|
|
|||
|
training_data = {
|
|||
|
"product_id": "P001",
|
|||
|
"model_type": "transformer",
|
|||
|
"epochs": 1,
|
|||
|
"training_mode": "product"
|
|||
|
}
|
|||
|
|
|||
|
headers = {
|
|||
|
'Content-Type': 'application/json',
|
|||
|
'Origin': 'http://localhost:5173'
|
|||
|
}
|
|||
|
|
|||
|
try:
|
|||
|
# 测试start_training端点
|
|||
|
response = requests.post(
|
|||
|
f"{api_base}/api/start_training",
|
|||
|
headers=headers,
|
|||
|
json=training_data,
|
|||
|
timeout=10
|
|||
|
)
|
|||
|
|
|||
|
if response.status_code == 200:
|
|||
|
result = response.json()
|
|||
|
print(f" ✅ POST /api/start_training: 成功")
|
|||
|
print(f" 任务ID: {result.get('task_id', 'N/A')}")
|
|||
|
else:
|
|||
|
print(f" ❌ POST /api/start_training: HTTP {response.status_code}")
|
|||
|
print(f" 错误: {response.text}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f" ❌ POST /api/start_training: 错误 - {e}")
|
|||
|
|
|||
|
# 测试结果总结
|
|||
|
print("\n" + "=" * 60)
|
|||
|
print("🎯 CORS测试结果总结:")
|
|||
|
print(f"✅ 成功: {success_count}/{total_tests} ({success_count/total_tests*100:.1f}%)")
|
|||
|
|
|||
|
if success_count == total_tests:
|
|||
|
print("🎉 CORS配置完全正常!")
|
|||
|
elif success_count > total_tests * 0.8:
|
|||
|
print("⚠️ CORS配置基本正常,少数端点需要检查")
|
|||
|
else:
|
|||
|
print("❌ CORS配置存在问题,需要进一步调试")
|
|||
|
|
|||
|
print("\n💡 故障排除建议:")
|
|||
|
if success_count < total_tests:
|
|||
|
print("1. 确认API服务器已启动: ./启动API服务器-修复版.bat")
|
|||
|
print("2. 检查防火墙设置")
|
|||
|
print("3. 验证端口5000是否被占用")
|
|||
|
print("4. 查看服务器控制台日志")
|
|||
|
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_cors_requests()
|