159 lines
5.0 KiB
Python
159 lines
5.0 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
简单的API训练测试 - 使用标准库测试API
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import urllib.request
|
|||
|
import urllib.parse
|
|||
|
import json
|
|||
|
import time
|
|||
|
|
|||
|
# 设置完整的UTF-8编码环境
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
|||
|
|
|||
|
# Windows控制台编码设置
|
|||
|
if os.name == 'nt':
|
|||
|
try:
|
|||
|
import subprocess
|
|||
|
subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False)
|
|||
|
|
|||
|
# 重新配置标准输出
|
|||
|
if hasattr(sys.stdout, 'reconfigure'):
|
|||
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
except Exception as e:
|
|||
|
print(f"警告: UTF-8编码设置失败: {e}")
|
|||
|
|
|||
|
def send_post_request(url, data):
|
|||
|
"""发送POST请求"""
|
|||
|
try:
|
|||
|
json_data = json.dumps(data).encode('utf-8')
|
|||
|
req = urllib.request.Request(
|
|||
|
url,
|
|||
|
data=json_data,
|
|||
|
headers={'Content-Type': 'application/json'}
|
|||
|
)
|
|||
|
|
|||
|
with urllib.request.urlopen(req, timeout=30) as response:
|
|||
|
return response.getcode(), json.loads(response.read().decode('utf-8'))
|
|||
|
except Exception as e:
|
|||
|
return None, str(e)
|
|||
|
|
|||
|
def send_get_request(url):
|
|||
|
"""发送GET请求"""
|
|||
|
try:
|
|||
|
with urllib.request.urlopen(url, timeout=10) as response:
|
|||
|
return response.getcode(), json.loads(response.read().decode('utf-8'))
|
|||
|
except Exception as e:
|
|||
|
return None, str(e)
|
|||
|
|
|||
|
def test_api_training():
|
|||
|
"""测试API训练功能"""
|
|||
|
|
|||
|
print("🧪 简单API训练测试")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
base_url = "http://127.0.0.1:5000"
|
|||
|
|
|||
|
# 0. 测试线程输出
|
|||
|
print("0️⃣ 测试API线程输出机制...")
|
|||
|
status, response = send_post_request(f"{base_url}/api/test-thread-output", {})
|
|||
|
if status == 200:
|
|||
|
print("✅ 线程输出测试成功")
|
|||
|
else:
|
|||
|
print(f"❌ 线程输出测试失败: {response}")
|
|||
|
|
|||
|
print("\n1️⃣ 测试简化训练...")
|
|||
|
status, response = send_post_request(f"{base_url}/api/test-training-simple", {})
|
|||
|
if status == 200:
|
|||
|
print("✅ 简化训练测试成功")
|
|||
|
else:
|
|||
|
print(f"❌ 简化训练测试失败: {response}")
|
|||
|
|
|||
|
print("\n" + "="*50)
|
|||
|
|
|||
|
# 1. 测试API连接
|
|||
|
print("🔍 测试API连接...")
|
|||
|
status, response = send_get_request(f"{base_url}/api/products")
|
|||
|
|
|||
|
if status != 200:
|
|||
|
print(f"❌ API连接失败: {response}")
|
|||
|
print("💡 请确保API服务器已启动")
|
|||
|
return
|
|||
|
|
|||
|
print("✅ API连接成功")
|
|||
|
|
|||
|
# 2. 获取产品信息
|
|||
|
products = response.get('data', [])
|
|||
|
if not products:
|
|||
|
print("❌ 没有产品数据")
|
|||
|
return
|
|||
|
|
|||
|
product_id = products[0]['product_id']
|
|||
|
product_name = products[0]['product_name']
|
|||
|
print(f"🎯 选择产品: {product_id} - {product_name}")
|
|||
|
|
|||
|
# 3. 发送训练请求
|
|||
|
print(f"\n🚀 发送训练请求...")
|
|||
|
training_data = {
|
|||
|
"product_id": product_id,
|
|||
|
"model_type": "transformer",
|
|||
|
"epochs": 2, # 快速测试
|
|||
|
"training_mode": "product"
|
|||
|
}
|
|||
|
|
|||
|
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
|
|||
|
print("\n" + "="*60)
|
|||
|
print("🔍 请密切观察API服务器控制台输出!")
|
|||
|
print("🔍 应该看到以下调试信息:")
|
|||
|
print(" - 🚀🚀🚀 THREAD START: 准备启动训练线程")
|
|||
|
print(" - 🔥🔥🔥 TRAIN_TASK ENTRY: 函数已被调用!")
|
|||
|
print(" - 🚀 训练任务开始")
|
|||
|
print(" - 训练器调用和进度信息")
|
|||
|
print("="*60)
|
|||
|
|
|||
|
# 发送训练请求
|
|||
|
status, response = send_post_request(f"{base_url}/api/training", training_data)
|
|||
|
|
|||
|
if status == 200:
|
|||
|
task_id = response.get('task_id')
|
|||
|
print(f"\n✅ 训练请求成功")
|
|||
|
print(f"🆔 任务ID: {task_id}")
|
|||
|
print(f"💬 消息: {response.get('message')}")
|
|||
|
|
|||
|
# 等待一段时间让训练开始
|
|||
|
print(f"\n⏳ 等待训练开始...")
|
|||
|
time.sleep(3)
|
|||
|
|
|||
|
# 检查任务状态
|
|||
|
print(f"📊 检查训练状态...")
|
|||
|
status, response = send_get_request(f"{base_url}/api/training")
|
|||
|
|
|||
|
if status == 200:
|
|||
|
tasks = response.get('data', [])
|
|||
|
for task in tasks:
|
|||
|
if task.get('task_id') == task_id:
|
|||
|
print(f"📋 任务状态: {task.get('status')}")
|
|||
|
if task.get('metrics'):
|
|||
|
print(f"📈 指标: {task.get('metrics')}")
|
|||
|
if task.get('error'):
|
|||
|
print(f"❌ 错误: {task.get('error')}")
|
|||
|
break
|
|||
|
|
|||
|
else:
|
|||
|
print(f"❌ 训练请求失败")
|
|||
|
print(f"🔴 错误: {response}")
|
|||
|
|
|||
|
print("\n" + "=" * 50)
|
|||
|
print("🎯 测试重点:")
|
|||
|
print("1. 检查API服务器控制台是否有调试输出")
|
|||
|
print("2. 确认线程是否正常启动")
|
|||
|
print("3. 观察训练函数是否被调用")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_api_training()
|