158 lines
6.5 KiB
Python
158 lines
6.5 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
测试API训练流程 - 直接调用训练API并观察控制台输出
|
|
"""
|
|
|
|
import requests
|
|
import json
|
|
import time
|
|
import sys
|
|
import os
|
|
|
|
# 设置UTF-8编码
|
|
if os.name == 'nt':
|
|
os.system('chcp 65001 >nul 2>&1')
|
|
import io
|
|
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
|
|
|
|
API_BASE_URL = "http://127.0.0.1:5000"
|
|
|
|
def test_training_api():
|
|
"""测试训练API调用"""
|
|
|
|
print("\n" + "="*60, flush=True)
|
|
print("🧪 API训练流程测试开始", flush=True)
|
|
print("="*60, flush=True)
|
|
|
|
# 检查API服务器状态
|
|
try:
|
|
response = requests.get(f"{API_BASE_URL}/api/health", timeout=5)
|
|
print(f"✅ API服务器状态: {response.status_code}", flush=True)
|
|
if response.status_code == 200:
|
|
print(f"📊 服务器响应: {response.json()}", flush=True)
|
|
except Exception as e:
|
|
print(f"❌ API服务器连接失败: {e}", flush=True)
|
|
return
|
|
|
|
# 检查是否有训练数据
|
|
try:
|
|
print("\n🔍 检查训练数据...", flush=True)
|
|
if not os.path.exists("pharmacy_sales_multi_store.csv"):
|
|
print("⚠️ 未找到训练数据文件,正在生成...", flush=True)
|
|
import subprocess
|
|
result = subprocess.run(['python', 'generate_multi_store_data.py'],
|
|
capture_output=True, text=True, cwd=".")
|
|
if result.returncode == 0:
|
|
print("✅ 训练数据生成成功", flush=True)
|
|
else:
|
|
print(f"❌ 数据生成失败: {result.stderr}", flush=True)
|
|
return
|
|
else:
|
|
print("✅ 找到训练数据文件", flush=True)
|
|
except Exception as e:
|
|
print(f"❌ 数据检查失败: {e}", flush=True)
|
|
|
|
# 获取产品列表
|
|
try:
|
|
print("\n📋 获取产品列表...", flush=True)
|
|
response = requests.get(f"{API_BASE_URL}/api/products", timeout=10)
|
|
if response.status_code == 200:
|
|
products = response.json().get('data', [])
|
|
print(f"✅ 找到 {len(products)} 个产品", flush=True)
|
|
if products:
|
|
product_id = products[0]['product_id']
|
|
print(f"🎯 选择产品: {product_id} - {products[0]['product_name']}", flush=True)
|
|
else:
|
|
print("❌ 没有找到产品数据", flush=True)
|
|
return
|
|
else:
|
|
print(f"❌ 获取产品列表失败: {response.status_code}", flush=True)
|
|
return
|
|
except Exception as e:
|
|
print(f"❌ 获取产品列表异常: {e}", flush=True)
|
|
return
|
|
|
|
# 启动训练任务
|
|
print(f"\n🚀 启动训练任务...", flush=True)
|
|
training_data = {
|
|
"product_id": product_id,
|
|
"model_type": "transformer",
|
|
"epochs": 3, # 使用较少的轮次快速测试
|
|
"training_mode": "product"
|
|
}
|
|
|
|
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False, indent=2)}", flush=True)
|
|
|
|
try:
|
|
response = requests.post(
|
|
f"{API_BASE_URL}/api/training",
|
|
json=training_data,
|
|
timeout=300 # 5分钟超时
|
|
)
|
|
|
|
print(f"\n📡 API响应状态: {response.status_code}", flush=True)
|
|
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
task_id = result.get('task_id')
|
|
print(f"✅ 训练任务已启动: {task_id}", flush=True)
|
|
print(f"💬 响应消息: {result.get('message', 'N/A')}", flush=True)
|
|
|
|
# 等待训练完成并检查状态
|
|
print(f"\n⏳ 等待训练完成...", flush=True)
|
|
for i in range(60): # 最多等待60秒
|
|
time.sleep(1)
|
|
try:
|
|
status_response = requests.get(f"{API_BASE_URL}/api/training", timeout=5)
|
|
if status_response.status_code == 200:
|
|
tasks = status_response.json().get('data', [])
|
|
current_task = None
|
|
for task in tasks:
|
|
if task.get('task_id') == task_id:
|
|
current_task = task
|
|
break
|
|
|
|
if current_task:
|
|
status = current_task.get('status', 'unknown')
|
|
print(f"📊 任务状态: {status} ({i+1}s)", end='\r', flush=True)
|
|
|
|
if status == 'completed':
|
|
print(f"\n✅ 训练完成!")
|
|
metrics = current_task.get('metrics', {})
|
|
if metrics:
|
|
print(f"📈 训练指标: {json.dumps(metrics, ensure_ascii=False, indent=2)}", flush=True)
|
|
else:
|
|
print("⚠️ 训练指标为空", flush=True)
|
|
break
|
|
elif status == 'failed':
|
|
print(f"\n❌ 训练失败!")
|
|
error = current_task.get('error', 'Unknown error')
|
|
print(f"🔴 错误信息: {error}", flush=True)
|
|
break
|
|
except Exception as e:
|
|
print(f"\n⚠️ 状态检查异常: {e}", flush=True)
|
|
else:
|
|
print(f"\n⏰ 训练超时,但任务可能仍在后台运行", flush=True)
|
|
|
|
else:
|
|
print(f"❌ 训练启动失败: {response.status_code}", flush=True)
|
|
try:
|
|
error_info = response.json()
|
|
print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False, indent=2)}", flush=True)
|
|
except:
|
|
print(f"🔴 错误内容: {response.text}", flush=True)
|
|
|
|
except requests.exceptions.Timeout:
|
|
print(f"⏰ 请求超时,训练可能仍在进行中", flush=True)
|
|
except Exception as e:
|
|
print(f"❌ 训练请求异常: {e}", flush=True)
|
|
|
|
print("\n" + "="*60, flush=True)
|
|
print("🎉 API训练流程测试完成", flush=True)
|
|
print("💡 请查看API服务器控制台输出以确认训练日志是否正常显示", flush=True)
|
|
print("="*60, flush=True)
|
|
|
|
if __name__ == "__main__":
|
|
test_training_api() |