ShopTRAINING/test/test_cors_fix.py
2025-07-02 11:05:23 +08:00

175 lines
5.9 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试CORS跨域请求修复效果
"""
import os
import sys
import requests
import json
# 设置环境变量
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_cors_requests():
"""测试CORS跨域请求"""
print("=" * 60)
print("🧪 CORS跨域请求修复测试")
print("=" * 60)
# API服务器地址
api_base = "http://127.0.0.1:5000"
# 模拟不同来源的请求
test_origins = [
"http://localhost:5173", # Vue开发服务器
"http://127.0.0.1:5173", # 数字IP版本
"http://localhost:3000", # 备用端口
None # 无Origin头直接访问
]
# 测试端点列表
test_endpoints = [
("/api/health", "GET"),
("/api/training", "GET"),
("/api/stores", "GET"),
("/api/products", "GET")
]
print(f"🎯 测试目标: {api_base}")
print(f"🔍 测试端点数量: {len(test_endpoints)}")
print(f"🌐 测试来源数量: {len(test_origins)}")
print()
# 测试结果统计
success_count = 0
total_tests = 0
for origin in test_origins:
origin_name = origin if origin else "Direct Access"
print(f"📡 测试来源: {origin_name}")
print("-" * 40)
# 设置请求头
headers = {
'Content-Type': 'application/json',
'Accept': 'application/json'
}
if origin:
headers['Origin'] = origin
for endpoint, method in test_endpoints:
total_tests += 1
url = f"{api_base}{endpoint}"
try:
# 首先发送OPTIONS预检请求
if origin:
options_response = requests.options(
url,
headers=headers,
timeout=5
)
# 检查预检请求响应
cors_headers = {
'Access-Control-Allow-Origin': options_response.headers.get('Access-Control-Allow-Origin'),
'Access-Control-Allow-Methods': options_response.headers.get('Access-Control-Allow-Methods'),
'Access-Control-Allow-Headers': options_response.headers.get('Access-Control-Allow-Headers')
}
if options_response.status_code != 200:
print(f"{method} {endpoint}: 预检失败 ({options_response.status_code})")
continue
# 发送实际请求
if method == "GET":
response = requests.get(url, headers=headers, timeout=5)
elif method == "POST":
response = requests.post(url, headers=headers, json={}, timeout=5)
else:
continue
# 检查响应
if response.status_code == 200:
# 检查CORS头
allow_origin = response.headers.get('Access-Control-Allow-Origin')
if allow_origin and (allow_origin == '*' or allow_origin == origin):
print(f"{method} {endpoint}: 成功 (CORS: {allow_origin})")
success_count += 1
else:
print(f" ⚠️ {method} {endpoint}: 成功但CORS头缺失")
success_count += 1
else:
print(f"{method} {endpoint}: HTTP {response.status_code}")
except requests.exceptions.ConnectionError:
print(f"{method} {endpoint}: 连接失败 (服务器未启动?)")
except requests.exceptions.Timeout:
print(f"{method} {endpoint}: 请求超时")
except Exception as e:
print(f"{method} {endpoint}: 错误 - {e}")
print()
# 专门测试训练端点的POST请求
print("🚀 专项测试: 训练端点POST请求")
print("-" * 40)
training_data = {
"product_id": "P001",
"model_type": "transformer",
"epochs": 1,
"training_mode": "product"
}
headers = {
'Content-Type': 'application/json',
'Origin': 'http://localhost:5173'
}
try:
# 测试start_training端点
response = requests.post(
f"{api_base}/api/start_training",
headers=headers,
json=training_data,
timeout=10
)
if response.status_code == 200:
result = response.json()
print(f" ✅ POST /api/start_training: 成功")
print(f" 任务ID: {result.get('task_id', 'N/A')}")
else:
print(f" ❌ POST /api/start_training: HTTP {response.status_code}")
print(f" 错误: {response.text}")
except Exception as e:
print(f" ❌ POST /api/start_training: 错误 - {e}")
# 测试结果总结
print("\n" + "=" * 60)
print("🎯 CORS测试结果总结:")
print(f"✅ 成功: {success_count}/{total_tests} ({success_count/total_tests*100:.1f}%)")
if success_count == total_tests:
print("🎉 CORS配置完全正常!")
elif success_count > total_tests * 0.8:
print("⚠️ CORS配置基本正常少数端点需要检查")
else:
print("❌ CORS配置存在问题需要进一步调试")
print("\n💡 故障排除建议:")
if success_count < total_tests:
print("1. 确认API服务器已启动: ./启动API服务器-修复版.bat")
print("2. 检查防火墙设置")
print("3. 验证端口5000是否被占用")
print("4. 查看服务器控制台日志")
print("=" * 60)
if __name__ == "__main__":
test_cors_requests()