ShopTRAINING/test/test_cors_fix.py

175 lines
5.9 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试CORS跨域请求修复效果
"""
import os
import sys
import requests
import json
# 设置环境变量
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_cors_requests():
"""测试CORS跨域请求"""
print("=" * 60)
print("🧪 CORS跨域请求修复测试")
print("=" * 60)
# API服务器地址
api_base = "http://127.0.0.1:5000"
# 模拟不同来源的请求
test_origins = [
"http://localhost:5173", # Vue开发服务器
"http://127.0.0.1:5173", # 数字IP版本
"http://localhost:3000", # 备用端口
None # 无Origin头直接访问
]
# 测试端点列表
test_endpoints = [
("/api/health", "GET"),
("/api/training", "GET"),
("/api/stores", "GET"),
("/api/products", "GET")
]
print(f"🎯 测试目标: {api_base}")
print(f"🔍 测试端点数量: {len(test_endpoints)}")
print(f"🌐 测试来源数量: {len(test_origins)}")
print()
# 测试结果统计
success_count = 0
total_tests = 0
for origin in test_origins:
origin_name = origin if origin else "Direct Access"
print(f"📡 测试来源: {origin_name}")
print("-" * 40)
# 设置请求头
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
if origin:
headers['Origin'] = origin
for endpoint, method in test_endpoints:
total_tests += 1
url = f"{api_base}{endpoint}"
try:
# 首先发送OPTIONS预检请求
if origin:
options_response = requests.options(
url,
headers=headers,
timeout=5
)
# 检查预检请求响应
cors_headers = {
'Access-Control-Allow-Origin': options_response.headers.get('Access-Control-Allow-Origin'),
'Access-Control-Allow-Methods': options_response.headers.get('Access-Control-Allow-Methods'),
'Access-Control-Allow-Headers': options_response.headers.get('Access-Control-Allow-Headers')
}
if options_response.status_code != 200:
print(f"{method} {endpoint}: 预检失败 ({options_response.status_code})")
continue
# 发送实际请求
if method == "GET":
response = requests.get(url, headers=headers, timeout=5)
elif method == "POST":
response = requests.post(url, headers=headers, json={}, timeout=5)
else:
continue
# 检查响应
if response.status_code == 200:
# 检查CORS头
allow_origin = response.headers.get('Access-Control-Allow-Origin')
if allow_origin and (allow_origin == '*' or allow_origin == origin):
print(f"{method} {endpoint}: 成功 (CORS: {allow_origin})")
success_count += 1
else:
print(f" ⚠️ {method} {endpoint}: 成功但CORS头缺失")
success_count += 1
else:
print(f"{method} {endpoint}: HTTP {response.status_code}")
except requests.exceptions.ConnectionError:
print(f"{method} {endpoint}: 连接失败 (服务器未启动?)")
except requests.exceptions.Timeout:
print(f"{method} {endpoint}: 请求超时")
except Exception as e:
print(f"{method} {endpoint}: 错误 - {e}")
print()
# 专门测试训练端点的POST请求
print("🚀 专项测试: 训练端点POST请求")
print("-" * 40)
training_data = {
"product_id": "P001",
"model_type": "transformer",
"epochs": 1,
"training_mode": "product"
}
headers = {
'Content-Type': 'application/json',
'Origin': 'http://localhost:5173'
}
try:
# 测试start_training端点
response = requests.post(
f"{api_base}/api/start_training",
headers=headers,
json=training_data,
timeout=10
)
if response.status_code == 200:
result = response.json()
print(f" ✅ POST /api/start_training: 成功")
print(f" 任务ID: {result.get('task_id', 'N/A')}")
else:
print(f" ❌ POST /api/start_training: HTTP {response.status_code}")
print(f" 错误: {response.text}")
except Exception as e:
print(f" ❌ POST /api/start_training: 错误 - {e}")
# 测试结果总结
print("\n" + "=" * 60)
print("🎯 CORS测试结果总结:")
print(f"✅ 成功: {success_count}/{total_tests} ({success_count/total_tests*100:.1f}%)")
if success_count == total_tests:
print("🎉 CORS配置完全正常!")
elif success_count > total_tests * 0.8:
print("⚠️ CORS配置基本正常少数端点需要检查")
else:
print("❌ CORS配置存在问题需要进一步调试")
print("\n💡 故障排除建议:")
if success_count < total_tests:
print("1. 确认API服务器已启动: ./启动API服务器-修复版.bat")
print("2. 检查防火墙设置")
print("3. 验证端口5000是否被占用")
print("4. 查看服务器控制台日志")
print("=" * 60)
if __name__ == "__main__":
test_cors_requests()