159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
简单的API训练测试 - 使用标准库测试API
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import urllib.request
|
||
import urllib.parse
|
||
import json
|
||
import time
|
||
|
||
# 设置完整的UTF-8编码环境
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
||
|
||
# Windows控制台编码设置
|
||
if os.name == 'nt':
|
||
try:
|
||
import subprocess
|
||
subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False)
|
||
|
||
# 重新配置标准输出
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
except Exception as e:
|
||
print(f"警告: UTF-8编码设置失败: {e}")
|
||
|
||
def send_post_request(url, data):
|
||
"""发送POST请求"""
|
||
try:
|
||
json_data = json.dumps(data).encode('utf-8')
|
||
req = urllib.request.Request(
|
||
url,
|
||
data=json_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
with urllib.request.urlopen(req, timeout=30) as response:
|
||
return response.getcode(), json.loads(response.read().decode('utf-8'))
|
||
except Exception as e:
|
||
return None, str(e)
|
||
|
||
def send_get_request(url):
|
||
"""发送GET请求"""
|
||
try:
|
||
with urllib.request.urlopen(url, timeout=10) as response:
|
||
return response.getcode(), json.loads(response.read().decode('utf-8'))
|
||
except Exception as e:
|
||
return None, str(e)
|
||
|
||
def test_api_training():
|
||
"""测试API训练功能"""
|
||
|
||
print("🧪 简单API训练测试")
|
||
print("=" * 50)
|
||
|
||
base_url = "http://127.0.0.1:5000"
|
||
|
||
# 0. 测试线程输出
|
||
print("0️⃣ 测试API线程输出机制...")
|
||
status, response = send_post_request(f"{base_url}/api/test-thread-output", {})
|
||
if status == 200:
|
||
print("✅ 线程输出测试成功")
|
||
else:
|
||
print(f"❌ 线程输出测试失败: {response}")
|
||
|
||
print("\n1️⃣ 测试简化训练...")
|
||
status, response = send_post_request(f"{base_url}/api/test-training-simple", {})
|
||
if status == 200:
|
||
print("✅ 简化训练测试成功")
|
||
else:
|
||
print(f"❌ 简化训练测试失败: {response}")
|
||
|
||
print("\n" + "="*50)
|
||
|
||
# 1. 测试API连接
|
||
print("🔍 测试API连接...")
|
||
status, response = send_get_request(f"{base_url}/api/products")
|
||
|
||
if status != 200:
|
||
print(f"❌ API连接失败: {response}")
|
||
print("💡 请确保API服务器已启动")
|
||
return
|
||
|
||
print("✅ API连接成功")
|
||
|
||
# 2. 获取产品信息
|
||
products = response.get('data', [])
|
||
if not products:
|
||
print("❌ 没有产品数据")
|
||
return
|
||
|
||
product_id = products[0]['product_id']
|
||
product_name = products[0]['product_name']
|
||
print(f"🎯 选择产品: {product_id} - {product_name}")
|
||
|
||
# 3. 发送训练请求
|
||
print(f"\n🚀 发送训练请求...")
|
||
training_data = {
|
||
"product_id": product_id,
|
||
"model_type": "transformer",
|
||
"epochs": 2, # 快速测试
|
||
"training_mode": "product"
|
||
}
|
||
|
||
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
|
||
print("\n" + "="*60)
|
||
print("🔍 请密切观察API服务器控制台输出!")
|
||
print("🔍 应该看到以下调试信息:")
|
||
print(" - 🚀🚀🚀 THREAD START: 准备启动训练线程")
|
||
print(" - 🔥🔥🔥 TRAIN_TASK ENTRY: 函数已被调用!")
|
||
print(" - 🚀 训练任务开始")
|
||
print(" - 训练器调用和进度信息")
|
||
print("="*60)
|
||
|
||
# 发送训练请求
|
||
status, response = send_post_request(f"{base_url}/api/training", training_data)
|
||
|
||
if status == 200:
|
||
task_id = response.get('task_id')
|
||
print(f"\n✅ 训练请求成功")
|
||
print(f"🆔 任务ID: {task_id}")
|
||
print(f"💬 消息: {response.get('message')}")
|
||
|
||
# 等待一段时间让训练开始
|
||
print(f"\n⏳ 等待训练开始...")
|
||
time.sleep(3)
|
||
|
||
# 检查任务状态
|
||
print(f"📊 检查训练状态...")
|
||
status, response = send_get_request(f"{base_url}/api/training")
|
||
|
||
if status == 200:
|
||
tasks = response.get('data', [])
|
||
for task in tasks:
|
||
if task.get('task_id') == task_id:
|
||
print(f"📋 任务状态: {task.get('status')}")
|
||
if task.get('metrics'):
|
||
print(f"📈 指标: {task.get('metrics')}")
|
||
if task.get('error'):
|
||
print(f"❌ 错误: {task.get('error')}")
|
||
break
|
||
|
||
else:
|
||
print(f"❌ 训练请求失败")
|
||
print(f"🔴 错误: {response}")
|
||
|
||
print("\n" + "=" * 50)
|
||
print("🎯 测试重点:")
|
||
print("1. 检查API服务器控制台是否有调试输出")
|
||
print("2. 确认线程是否正常启动")
|
||
print("3. 观察训练函数是否被调用")
|
||
print("=" * 50)
|
||
|
||
if __name__ == "__main__":
|
||
test_api_training() |