#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 测试CORS跨域请求修复效果 """ import os import sys import requests import json # 设置环境变量 os.environ['PYTHONIOENCODING'] = 'utf-8' def test_cors_requests(): """测试CORS跨域请求""" print("=" * 60) print("🧪 CORS跨域请求修复测试") print("=" * 60) # API服务器地址 api_base = "http://127.0.0.1:5000" # 模拟不同来源的请求 test_origins = [ "http://localhost:5173", # Vue开发服务器 "http://127.0.0.1:5173", # 数字IP版本 "http://localhost:3000", # 备用端口 None # 无Origin头(直接访问) ] # 测试端点列表 test_endpoints = [ ("/api/health", "GET"), ("/api/training", "GET"), ("/api/stores", "GET"), ("/api/products", "GET") ] print(f"🎯 测试目标: {api_base}") print(f"🔍 测试端点数量: {len(test_endpoints)}") print(f"🌐 测试来源数量: {len(test_origins)}") print() # 测试结果统计 success_count = 0 total_tests = 0 for origin in test_origins: origin_name = origin if origin else "Direct Access" print(f"📡 测试来源: {origin_name}") print("-" * 40) # 设置请求头 headers = { 'Content-Type': 'application/json', 'Accept': 'application/json' } if origin: headers['Origin'] = origin for endpoint, method in test_endpoints: total_tests += 1 url = f"{api_base}{endpoint}" try: # 首先发送OPTIONS预检请求 if origin: options_response = requests.options( url, headers=headers, timeout=5 ) # 检查预检请求响应 cors_headers = { 'Access-Control-Allow-Origin': options_response.headers.get('Access-Control-Allow-Origin'), 'Access-Control-Allow-Methods': options_response.headers.get('Access-Control-Allow-Methods'), 'Access-Control-Allow-Headers': options_response.headers.get('Access-Control-Allow-Headers') } if options_response.status_code != 200: print(f" ❌ {method} {endpoint}: 预检失败 ({options_response.status_code})") continue # 发送实际请求 if method == "GET": response = requests.get(url, headers=headers, timeout=5) elif method == "POST": response = requests.post(url, headers=headers, json={}, timeout=5) else: continue # 检查响应 if response.status_code == 200: # 检查CORS头 allow_origin = response.headers.get('Access-Control-Allow-Origin') if allow_origin and (allow_origin == '*' or allow_origin == origin): print(f" ✅ {method} {endpoint}: 成功 (CORS: {allow_origin})") success_count += 1 else: print(f" ⚠️ {method} {endpoint}: 成功但CORS头缺失") success_count += 1 else: print(f" ❌ {method} {endpoint}: HTTP {response.status_code}") except requests.exceptions.ConnectionError: print(f" ❌ {method} {endpoint}: 连接失败 (服务器未启动?)") except requests.exceptions.Timeout: print(f" ❌ {method} {endpoint}: 请求超时") except Exception as e: print(f" ❌ {method} {endpoint}: 错误 - {e}") print() # 专门测试训练端点的POST请求 print("🚀 专项测试: 训练端点POST请求") print("-" * 40) training_data = { "product_id": "P001", "model_type": "transformer", "epochs": 1, "training_mode": "product" } headers = { 'Content-Type': 'application/json', 'Origin': 'http://localhost:5173' } try: # 测试start_training端点 response = requests.post( f"{api_base}/api/start_training", headers=headers, json=training_data, timeout=10 ) if response.status_code == 200: result = response.json() print(f" ✅ POST /api/start_training: 成功") print(f" 任务ID: {result.get('task_id', 'N/A')}") else: print(f" ❌ POST /api/start_training: HTTP {response.status_code}") print(f" 错误: {response.text}") except Exception as e: print(f" ❌ POST /api/start_training: 错误 - {e}") # 测试结果总结 print("\n" + "=" * 60) print("🎯 CORS测试结果总结:") print(f"✅ 成功: {success_count}/{total_tests} ({success_count/total_tests*100:.1f}%)") if success_count == total_tests: print("🎉 CORS配置完全正常!") elif success_count > total_tests * 0.8: print("⚠️ CORS配置基本正常,少数端点需要检查") else: print("❌ CORS配置存在问题,需要进一步调试") print("\n💡 故障排除建议:") if success_count < total_tests: print("1. 确认API服务器已启动: ./启动API服务器-修复版.bat") print("2. 检查防火墙设置") print("3. 验证端口5000是否被占用") print("4. 查看服务器控制台日志") print("=" * 60) if __name__ == "__main__": test_cors_requests()