100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
启动API服务器并立即测试训练 - 观察控制台输出
|
||
"""
|
||
|
||
import sys
|
||
import os
|
||
import threading
|
||
import time
|
||
import requests
|
||
import json
|
||
|
||
# 设置UTF-8编码
|
||
if os.name == 'nt':
|
||
os.system('chcp 65001 >nul 2>&1')
|
||
import io
|
||
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
||
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
||
|
||
# 添加server目录到路径
|
||
server_path = os.path.join(os.path.dirname(__file__), 'server')
|
||
sys.path.insert(0, server_path)
|
||
|
||
def start_api_server():
|
||
"""启动API服务器"""
|
||
try:
|
||
print("🚀 正在启动API服务器...", flush=True)
|
||
|
||
# 导入并启动API
|
||
import api
|
||
# API会在导入时自动启动
|
||
|
||
except Exception as e:
|
||
print(f"❌ API服务器启动失败: {e}", flush=True)
|
||
|
||
def test_training_after_delay():
|
||
"""延迟后测试训练"""
|
||
time.sleep(3) # 等待API服务器启动
|
||
|
||
print("\n" + "="*60, flush=True)
|
||
print("🧪 开始训练测试", flush=True)
|
||
print("="*60, flush=True)
|
||
|
||
try:
|
||
# 测试连接
|
||
response = requests.get("http://127.0.0.1:5000/api/products", timeout=5)
|
||
if response.status_code == 200:
|
||
products = response.json().get('data', [])
|
||
if products:
|
||
product_id = products[0]['product_id']
|
||
print(f"✅ 连接成功,选择产品: {product_id}", flush=True)
|
||
|
||
# 启动训练
|
||
training_data = {
|
||
"product_id": product_id,
|
||
"model_type": "transformer",
|
||
"epochs": 3,
|
||
"training_mode": "product"
|
||
}
|
||
|
||
print(f"\n🚀 发送训练请求...", flush=True)
|
||
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}", flush=True)
|
||
|
||
response = requests.post(
|
||
"http://127.0.0.1:5000/api/training",
|
||
json=training_data,
|
||
timeout=300
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
print(f"✅ 训练请求成功: {result.get('task_id')}", flush=True)
|
||
print(f"💬 消息: {result.get('message')}", flush=True)
|
||
else:
|
||
print(f"❌ 训练请求失败: {response.status_code}", flush=True)
|
||
print(f"🔴 错误: {response.text}", flush=True)
|
||
else:
|
||
print("❌ 没有找到产品", flush=True)
|
||
else:
|
||
print(f"❌ API连接失败: {response.status_code}", flush=True)
|
||
|
||
except Exception as e:
|
||
print(f"❌ 训练测试失败: {e}", flush=True)
|
||
|
||
def main():
|
||
"""主函数"""
|
||
print("\n" + "="*60, flush=True)
|
||
print("🔧 API服务器 + 训练测试", flush=True)
|
||
print("="*60, flush=True)
|
||
|
||
# 在后台线程中启动训练测试
|
||
test_thread = threading.Thread(target=test_training_after_delay, daemon=True)
|
||
test_thread.start()
|
||
|
||
# 在主线程中启动API服务器(这会阻塞)
|
||
start_api_server()
|
||
|
||
if __name__ == "__main__":
|
||
main() |