ShopTRAINING/test/test_training_api.py
2025-07-02 11:05:23 +08:00

158 lines
6.5 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试API训练流程 - 直接调用训练API并观察控制台输出
"""
import requests
import json
import time
import sys
import os
# 设置UTF-8编码
if os.name == 'nt':
os.system('chcp 65001 >nul 2>&1')
import io
sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True)
sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True)
API_BASE_URL = "http://127.0.0.1:5000"
def test_training_api():
"""测试训练API调用"""
print("\n" + "="*60, flush=True)
print("🧪 API训练流程测试开始", flush=True)
print("="*60, flush=True)
# 检查API服务器状态
try:
response = requests.get(f"{API_BASE_URL}/api/health", timeout=5)
print(f"✅ API服务器状态: {response.status_code}", flush=True)
if response.status_code == 200:
print(f"📊 服务器响应: {response.json()}", flush=True)
except Exception as e:
print(f"❌ API服务器连接失败: {e}", flush=True)
return
# 检查是否有训练数据
try:
print("\n🔍 检查训练数据...", flush=True)
if not os.path.exists("pharmacy_sales_multi_store.csv"):
print("⚠️ 未找到训练数据文件,正在生成...", flush=True)
import subprocess
result = subprocess.run(['python', 'generate_multi_store_data.py'],
capture_output=True, text=True, cwd=".")
if result.returncode == 0:
print("✅ 训练数据生成成功", flush=True)
else:
print(f"❌ 数据生成失败: {result.stderr}", flush=True)
return
else:
print("✅ 找到训练数据文件", flush=True)
except Exception as e:
print(f"❌ 数据检查失败: {e}", flush=True)
# 获取产品列表
try:
print("\n📋 获取产品列表...", flush=True)
response = requests.get(f"{API_BASE_URL}/api/products", timeout=10)
if response.status_code == 200:
products = response.json().get('data', [])
print(f"✅ 找到 {len(products)} 个产品", flush=True)
if products:
product_id = products[0]['product_id']
print(f"🎯 选择产品: {product_id} - {products[0]['product_name']}", flush=True)
else:
print("❌ 没有找到产品数据", flush=True)
return
else:
print(f"❌ 获取产品列表失败: {response.status_code}", flush=True)
return
except Exception as e:
print(f"❌ 获取产品列表异常: {e}", flush=True)
return
# 启动训练任务
print(f"\n🚀 启动训练任务...", flush=True)
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 3, # 使用较少的轮次快速测试
"training_mode": "product"
}
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False, indent=2)}", flush=True)
try:
response = requests.post(
f"{API_BASE_URL}/api/training",
json=training_data,
timeout=300 # 5分钟超时
)
print(f"\n📡 API响应状态: {response.status_code}", flush=True)
if response.status_code == 200:
result = response.json()
task_id = result.get('task_id')
print(f"✅ 训练任务已启动: {task_id}", flush=True)
print(f"💬 响应消息: {result.get('message', 'N/A')}", flush=True)
# 等待训练完成并检查状态
print(f"\n⏳ 等待训练完成...", flush=True)
for i in range(60): # 最多等待60秒
time.sleep(1)
try:
status_response = requests.get(f"{API_BASE_URL}/api/training", timeout=5)
if status_response.status_code == 200:
tasks = status_response.json().get('data', [])
current_task = None
for task in tasks:
if task.get('task_id') == task_id:
current_task = task
break
if current_task:
status = current_task.get('status', 'unknown')
print(f"📊 任务状态: {status} ({i+1}s)", end='\r', flush=True)
if status == 'completed':
print(f"\n✅ 训练完成!")
metrics = current_task.get('metrics', {})
if metrics:
print(f"📈 训练指标: {json.dumps(metrics, ensure_ascii=False, indent=2)}", flush=True)
else:
print("⚠️ 训练指标为空", flush=True)
break
elif status == 'failed':
print(f"\n❌ 训练失败!")
error = current_task.get('error', 'Unknown error')
print(f"🔴 错误信息: {error}", flush=True)
break
except Exception as e:
print(f"\n⚠️ 状态检查异常: {e}", flush=True)
else:
print(f"\n⏰ 训练超时,但任务可能仍在后台运行", flush=True)
else:
print(f"❌ 训练启动失败: {response.status_code}", flush=True)
try:
error_info = response.json()
print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False, indent=2)}", flush=True)
except:
print(f"🔴 错误内容: {response.text}", flush=True)
except requests.exceptions.Timeout:
print(f"⏰ 请求超时,训练可能仍在进行中", flush=True)
except Exception as e:
print(f"❌ 训练请求异常: {e}", flush=True)
print("\n" + "="*60, flush=True)
print("🎉 API训练流程测试完成", flush=True)
print("💡 请查看API服务器控制台输出以确认训练日志是否正常显示", flush=True)
print("="*60, flush=True)
if __name__ == "__main__":
test_training_api()