ShopTRAINING/test/start_api_with_training_test.py
2025-07-02 11:05:23 +08:00

100 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
启动API服务器并立即测试训练 - 观察控制台输出
"""
import sys
import os
import threading
import time
import requests
import json
# 设置UTF-8编码
if os.name == 'nt':
os.system('chcp 65001 >nul 2>&1')
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
# 添加server目录到路径
server_path = os.path.join(os.path.dirname(__file__), 'server')
sys.path.insert(0, server_path)
def start_api_server():
"""启动API服务器"""
try:
print("🚀 正在启动API服务器...", flush=True)
# 导入并启动API
import api
# API会在导入时自动启动
except Exception as e:
print(f"❌ API服务器启动失败: {e}", flush=True)
def test_training_after_delay():
"""延迟后测试训练"""
time.sleep(3) # 等待API服务器启动
print("\n" + "="*60, flush=True)
print("🧪 开始训练测试", flush=True)
print("="*60, flush=True)
try:
# 测试连接
response = requests.get("http://127.0.0.1:5000/api/products", timeout=5)
if response.status_code == 200:
products = response.json().get('data', [])
if products:
product_id = products[0]['product_id']
print(f"✅ 连接成功,选择产品: {product_id}", flush=True)
# 启动训练
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 3,
"training_mode": "product"
}
print(f"\n🚀 发送训练请求...", flush=True)
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}", flush=True)
response = requests.post(
"http://127.0.0.1:5000/api/training",
json=training_data,
timeout=300
)
if response.status_code == 200:
result = response.json()
print(f"✅ 训练请求成功: {result.get('task_id')}", flush=True)
print(f"💬 消息: {result.get('message')}", flush=True)
else:
print(f"❌ 训练请求失败: {response.status_code}", flush=True)
print(f"🔴 错误: {response.text}", flush=True)
else:
print("❌ 没有找到产品", flush=True)
else:
print(f"❌ API连接失败: {response.status_code}", flush=True)
except Exception as e:
print(f"❌ 训练测试失败: {e}", flush=True)
def main():
"""主函数"""
print("\n" + "="*60, flush=True)
print("🔧 API服务器 + 训练测试", flush=True)
print("="*60, flush=True)
# 在后台线程中启动训练测试
test_thread = threading.Thread(target=test_training_after_delay, daemon=True)
test_thread.start()
# 在主线程中启动API服务器这会阻塞
start_api_server()
if __name__ == "__main__":
main()