ShopTRAINING/test/start_api_with_training_test.py

100 lines
3.3 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
启动API服务器并立即测试训练 - 观察控制台输出
"""
import sys
import os
import threading
import time
import requests
import json
# 设置UTF-8编码
if os.name == 'nt':
os.system('chcp 65001 >nul 2>&1')
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
# 添加server目录到路径
server_path = os.path.join(os.path.dirname(__file__), 'server')
sys.path.insert(0, server_path)
def start_api_server():
"""启动API服务器"""
try:
print("🚀 正在启动API服务器...", flush=True)
# 导入并启动API
import api
# API会在导入时自动启动
except Exception as e:
print(f"❌ API服务器启动失败: {e}", flush=True)
def test_training_after_delay():
"""延迟后测试训练"""
time.sleep(3) # 等待API服务器启动
print("\n" + "="*60, flush=True)
print("🧪 开始训练测试", flush=True)
print("="*60, flush=True)
try:
# 测试连接
response = requests.get("http://127.0.0.1:5000/api/products", timeout=5)
if response.status_code == 200:
products = response.json().get('data', [])
if products:
product_id = products[0]['product_id']
print(f"✅ 连接成功,选择产品: {product_id}", flush=True)
# 启动训练
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 3,
"training_mode": "product"
}
print(f"\n🚀 发送训练请求...", flush=True)
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}", flush=True)
response = requests.post(
"http://127.0.0.1:5000/api/training",
json=training_data,
timeout=300
)
if response.status_code == 200:
result = response.json()
print(f"✅ 训练请求成功: {result.get('task_id')}", flush=True)
print(f"💬 消息: {result.get('message')}", flush=True)
else:
print(f"❌ 训练请求失败: {response.status_code}", flush=True)
print(f"🔴 错误: {response.text}", flush=True)
else:
print("❌ 没有找到产品", flush=True)
else:
print(f"❌ API连接失败: {response.status_code}", flush=True)
except Exception as e:
print(f"❌ 训练测试失败: {e}", flush=True)
def main():
"""主函数"""
print("\n" + "="*60, flush=True)
print("🔧 API服务器 + 训练测试", flush=True)
print("="*60, flush=True)
# 在后台线程中启动训练测试
test_thread = threading.Thread(target=test_training_after_delay, daemon=True)
test_thread.start()
# 在主线程中启动API服务器这会阻塞
start_api_server()
if __name__ == "__main__":
main()