ShopTRAINING/test/simple_api_training_test.py

159 lines
5.0 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简单的API训练测试 - 使用标准库测试API
"""
import os
import sys
import urllib.request
import urllib.parse
import json
import time
# 设置完整的UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# Windows控制台编码设置
if os.name == 'nt':
try:
import subprocess
subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False)
# 重新配置标准输出
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"警告: UTF-8编码设置失败: {e}")
def send_post_request(url, data):
"""发送POST请求"""
try:
json_data = json.dumps(data).encode('utf-8')
req = urllib.request.Request(
url,
data=json_data,
headers={'Content-Type': 'application/json'}
)
with urllib.request.urlopen(req, timeout=30) as response:
return response.getcode(), json.loads(response.read().decode('utf-8'))
except Exception as e:
return None, str(e)
def send_get_request(url):
"""发送GET请求"""
try:
with urllib.request.urlopen(url, timeout=10) as response:
return response.getcode(), json.loads(response.read().decode('utf-8'))
except Exception as e:
return None, str(e)
def test_api_training():
"""测试API训练功能"""
print("🧪 简单API训练测试")
print("=" * 50)
base_url = "http://127.0.0.1:5000"
# 0. 测试线程输出
print("0⃣ 测试API线程输出机制...")
status, response = send_post_request(f"{base_url}/api/test-thread-output", {})
if status == 200:
print("✅ 线程输出测试成功")
else:
print(f"❌ 线程输出测试失败: {response}")
print("\n1⃣ 测试简化训练...")
status, response = send_post_request(f"{base_url}/api/test-training-simple", {})
if status == 200:
print("✅ 简化训练测试成功")
else:
print(f"❌ 简化训练测试失败: {response}")
print("\n" + "="*50)
# 1. 测试API连接
print("🔍 测试API连接...")
status, response = send_get_request(f"{base_url}/api/products")
if status != 200:
print(f"❌ API连接失败: {response}")
print("💡 请确保API服务器已启动")
return
print("✅ API连接成功")
# 2. 获取产品信息
products = response.get('data', [])
if not products:
print("❌ 没有产品数据")
return
product_id = products[0]['product_id']
product_name = products[0]['product_name']
print(f"🎯 选择产品: {product_id} - {product_name}")
# 3. 发送训练请求
print(f"\n🚀 发送训练请求...")
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 2, # 快速测试
"training_mode": "product"
}
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
print("\n" + "="*60)
print("🔍 请密切观察API服务器控制台输出")
print("🔍 应该看到以下调试信息:")
print(" - 🚀🚀🚀 THREAD START: 准备启动训练线程")
print(" - 🔥🔥🔥 TRAIN_TASK ENTRY: 函数已被调用!")
print(" - 🚀 训练任务开始")
print(" - 训练器调用和进度信息")
print("="*60)
# 发送训练请求
status, response = send_post_request(f"{base_url}/api/training", training_data)
if status == 200:
task_id = response.get('task_id')
print(f"\n✅ 训练请求成功")
print(f"🆔 任务ID: {task_id}")
print(f"💬 消息: {response.get('message')}")
# 等待一段时间让训练开始
print(f"\n⏳ 等待训练开始...")
time.sleep(3)
# 检查任务状态
print(f"📊 检查训练状态...")
status, response = send_get_request(f"{base_url}/api/training")
if status == 200:
tasks = response.get('data', [])
for task in tasks:
if task.get('task_id') == task_id:
print(f"📋 任务状态: {task.get('status')}")
if task.get('metrics'):
print(f"📈 指标: {task.get('metrics')}")
if task.get('error'):
print(f"❌ 错误: {task.get('error')}")
break
else:
print(f"❌ 训练请求失败")
print(f"🔴 错误: {response}")
print("\n" + "=" * 50)
print("🎯 测试重点:")
print("1. 检查API服务器控制台是否有调试输出")
print("2. 确认线程是否正常启动")
print("3. 观察训练函数是否被调用")
print("=" * 50)
if __name__ == "__main__":
test_api_training()