100 lines
3.3 KiB
Python
100 lines
3.3 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
启动API服务器并立即测试训练 - 观察控制台输出
|
|||
|
"""
|
|||
|
|
|||
|
import sys
|
|||
|
import os
|
|||
|
import threading
|
|||
|
import time
|
|||
|
import requests
|
|||
|
import json
|
|||
|
|
|||
|
# 设置UTF-8编码
|
|||
|
if os.name == 'nt':
|
|||
|
os.system('chcp 65001 >nul 2>&1')
|
|||
|
import io
|
|||
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|||
|
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|||
|
|
|||
|
# 添加server目录到路径
|
|||
|
server_path = os.path.join(os.path.dirname(__file__), 'server')
|
|||
|
sys.path.insert(0, server_path)
|
|||
|
|
|||
|
def start_api_server():
|
|||
|
"""启动API服务器"""
|
|||
|
try:
|
|||
|
print("🚀 正在启动API服务器...", flush=True)
|
|||
|
|
|||
|
# 导入并启动API
|
|||
|
import api
|
|||
|
# API会在导入时自动启动
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ API服务器启动失败: {e}", flush=True)
|
|||
|
|
|||
|
def test_training_after_delay():
|
|||
|
"""延迟后测试训练"""
|
|||
|
time.sleep(3) # 等待API服务器启动
|
|||
|
|
|||
|
print("\n" + "="*60, flush=True)
|
|||
|
print("🧪 开始训练测试", flush=True)
|
|||
|
print("="*60, flush=True)
|
|||
|
|
|||
|
try:
|
|||
|
# 测试连接
|
|||
|
response = requests.get("http://127.0.0.1:5000/api/products", timeout=5)
|
|||
|
if response.status_code == 200:
|
|||
|
products = response.json().get('data', [])
|
|||
|
if products:
|
|||
|
product_id = products[0]['product_id']
|
|||
|
print(f"✅ 连接成功,选择产品: {product_id}", flush=True)
|
|||
|
|
|||
|
# 启动训练
|
|||
|
training_data = {
|
|||
|
"product_id": product_id,
|
|||
|
"model_type": "transformer",
|
|||
|
"epochs": 3,
|
|||
|
"training_mode": "product"
|
|||
|
}
|
|||
|
|
|||
|
print(f"\n🚀 发送训练请求...", flush=True)
|
|||
|
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}", flush=True)
|
|||
|
|
|||
|
response = requests.post(
|
|||
|
"http://127.0.0.1:5000/api/training",
|
|||
|
json=training_data,
|
|||
|
timeout=300
|
|||
|
)
|
|||
|
|
|||
|
if response.status_code == 200:
|
|||
|
result = response.json()
|
|||
|
print(f"✅ 训练请求成功: {result.get('task_id')}", flush=True)
|
|||
|
print(f"💬 消息: {result.get('message')}", flush=True)
|
|||
|
else:
|
|||
|
print(f"❌ 训练请求失败: {response.status_code}", flush=True)
|
|||
|
print(f"🔴 错误: {response.text}", flush=True)
|
|||
|
else:
|
|||
|
print("❌ 没有找到产品", flush=True)
|
|||
|
else:
|
|||
|
print(f"❌ API连接失败: {response.status_code}", flush=True)
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 训练测试失败: {e}", flush=True)
|
|||
|
|
|||
|
def main():
|
|||
|
"""主函数"""
|
|||
|
print("\n" + "="*60, flush=True)
|
|||
|
print("🔧 API服务器 + 训练测试", flush=True)
|
|||
|
print("="*60, flush=True)
|
|||
|
|
|||
|
# 在后台线程中启动训练测试
|
|||
|
test_thread = threading.Thread(target=test_training_after_delay, daemon=True)
|
|||
|
test_thread.start()
|
|||
|
|
|||
|
# 在主线程中启动API服务器(这会阻塞)
|
|||
|
start_api_server()
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
main()
|