ShopTRAINING/test/simple_api_training_test.py
2025-07-02 11:05:23 +08:00

159 lines
5.0 KiB
Python
Raw Permalink Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
简单的API训练测试 - 使用标准库测试API
"""
import os
import sys
import urllib.request
import urllib.parse
import json
import time
# 设置完整的UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# Windows控制台编码设置
if os.name == 'nt':
try:
import subprocess
subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False)
# 重新配置标准输出
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"警告: UTF-8编码设置失败: {e}")
def send_post_request(url, data):
"""发送POST请求"""
try:
json_data = json.dumps(data).encode('utf-8')
req = urllib.request.Request(
url,
data=json_data,
headers={'Content-Type': 'application/json'}
)
with urllib.request.urlopen(req, timeout=30) as response:
return response.getcode(), json.loads(response.read().decode('utf-8'))
except Exception as e:
return None, str(e)
def send_get_request(url):
"""发送GET请求"""
try:
with urllib.request.urlopen(url, timeout=10) as response:
return response.getcode(), json.loads(response.read().decode('utf-8'))
except Exception as e:
return None, str(e)
def test_api_training():
"""测试API训练功能"""
print("🧪 简单API训练测试")
print("=" * 50)
base_url = "http://127.0.0.1:5000"
# 0. 测试线程输出
print("0⃣ 测试API线程输出机制...")
status, response = send_post_request(f"{base_url}/api/test-thread-output", {})
if status == 200:
print("✅ 线程输出测试成功")
else:
print(f"❌ 线程输出测试失败: {response}")
print("\n1⃣ 测试简化训练...")
status, response = send_post_request(f"{base_url}/api/test-training-simple", {})
if status == 200:
print("✅ 简化训练测试成功")
else:
print(f"❌ 简化训练测试失败: {response}")
print("\n" + "="*50)
# 1. 测试API连接
print("🔍 测试API连接...")
status, response = send_get_request(f"{base_url}/api/products")
if status != 200:
print(f"❌ API连接失败: {response}")
print("💡 请确保API服务器已启动")
return
print("✅ API连接成功")
# 2. 获取产品信息
products = response.get('data', [])
if not products:
print("❌ 没有产品数据")
return
product_id = products[0]['product_id']
product_name = products[0]['product_name']
print(f"🎯 选择产品: {product_id} - {product_name}")
# 3. 发送训练请求
print(f"\n🚀 发送训练请求...")
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 2, # 快速测试
"training_mode": "product"
}
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
print("\n" + "="*60)
print("🔍 请密切观察API服务器控制台输出")
print("🔍 应该看到以下调试信息:")
print(" - 🚀🚀🚀 THREAD START: 准备启动训练线程")
print(" - 🔥🔥🔥 TRAIN_TASK ENTRY: 函数已被调用!")
print(" - 🚀 训练任务开始")
print(" - 训练器调用和进度信息")
print("="*60)
# 发送训练请求
status, response = send_post_request(f"{base_url}/api/training", training_data)
if status == 200:
task_id = response.get('task_id')
print(f"\n✅ 训练请求成功")
print(f"🆔 任务ID: {task_id}")
print(f"💬 消息: {response.get('message')}")
# 等待一段时间让训练开始
print(f"\n⏳ 等待训练开始...")
time.sleep(3)
# 检查任务状态
print(f"📊 检查训练状态...")
status, response = send_get_request(f"{base_url}/api/training")
if status == 200:
tasks = response.get('data', [])
for task in tasks:
if task.get('task_id') == task_id:
print(f"📋 任务状态: {task.get('status')}")
if task.get('metrics'):
print(f"📈 指标: {task.get('metrics')}")
if task.get('error'):
print(f"❌ 错误: {task.get('error')}")
break
else:
print(f"❌ 训练请求失败")
print(f"🔴 错误: {response}")
print("\n" + "=" * 50)
print("🎯 测试重点:")
print("1. 检查API服务器控制台是否有调试输出")
print("2. 确认线程是否正常启动")
print("3. 观察训练函数是否被调用")
print("=" * 50)
if __name__ == "__main__":
test_api_training()