ShopTRAINING/test/http_training_test.py
2025-07-02 11:05:23 +08:00

173 lines
6.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
HTTP训练测试 - 调用已运行的API服务器进行训练测试
要求: 请先在另一个终端运行 uv run server/api.py 启动API服务器
"""
import os
import sys
# 在导入其他模块前设置编码
if os.name == 'nt':
# Windows系统编码设置
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.system('chcp 65001 >nul 2>&1')
import requests
import json
import time
def test_api_training():
"""测试API训练功能"""
print("🧪 API训练测试开始")
print("="*50)
api_url = "http://127.0.0.1:5000"
try:
# 1. 检查API服务器状态
print("🔍 检查API服务器状态...")
try:
response = requests.get(f"{api_url}/api/products", timeout=5)
if response.status_code != 200:
print("❌ API服务器未运行")
print("💡 请先运行: uv run server/api.py")
return False
print("✅ API服务器运行正常")
except requests.exceptions.ConnectionError:
print("❌ 无法连接到API服务器")
print("💡 请先运行: uv run server/api.py")
return False
# 2. 获取产品列表
print("\n📋 获取产品列表...")
products = response.json().get('data', [])
if not products:
print("❌ 没有找到产品数据")
return False
product_id = products[0]['product_id']
product_name = products[0]['product_name']
print(f"✅ 选择产品: {product_id} - {product_name}")
# 3. 启动训练任务
print(f"\n🚀 启动训练任务...")
training_data = {
"product_id": product_id,
"model_type": "transformer",
"epochs": 3, # 使用较少轮次快速测试
"training_mode": "product"
}
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
print("🔍 请观察API服务器控制台输出...")
print("-" * 60)
response = requests.post(
f"{api_url}/api/training",
json=training_data,
timeout=180 # 3分钟超时
)
print("-" * 60)
print(f"📡 API响应状态: {response.status_code}")
if response.status_code == 200:
result = response.json()
task_id = result.get('task_id')
print(f"✅ 训练任务已启动")
print(f"🔖 任务ID: {task_id}")
print(f"💬 消息: {result.get('message', 'N/A')}")
# 4. 监控训练状态
print(f"\n⏳ 监控训练状态...")
max_checks = 30 # 最多检查30次
for i in range(max_checks):
time.sleep(2) # 每2秒检查一次
try:
status_response = requests.get(f"{api_url}/api/training", timeout=5)
if status_response.status_code == 200:
tasks = status_response.json().get('data', [])
current_task = None
for task in tasks:
if task.get('task_id') == task_id:
current_task = task
break
if current_task:
status = current_task.get('status', 'unknown')
print(f"📊 训练状态: {status} (检查 {i+1}/{max_checks})")
if status == 'completed':
print(f"\n✅ 训练成功完成!")
metrics = current_task.get('metrics', {})
if metrics:
print(f"📈 训练指标:")
for key, value in metrics.items():
if isinstance(value, (int, float)):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
else:
print("⚠️ 训练指标为空")
return True
elif status == 'failed':
print(f"\n❌ 训练失败")
error = current_task.get('error', 'Unknown error')
print(f"🔴 错误信息: {error}")
return False
else:
print(f"⚠️ 找不到任务 {task_id}")
except Exception as e:
print(f"⚠️ 状态检查异常: {e}")
print(f"\n⏰ 训练监控超时,但任务可能仍在运行")
print(f"💡 请查看API服务器控制台确认训练状态")
else:
print(f"❌ 训练启动失败: {response.status_code}")
try:
error_info = response.json()
print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False)}")
except:
print(f"🔴 错误内容: {response.text}")
return False
except Exception as e:
print(f"❌ 测试异常: {e}")
return False
return True
def main():
"""主函数"""
print("\n" + "="*60)
print("🔧 药店销售预测系统 - API训练测试")
print("="*60)
print("📋 测试步骤:")
print(" 1. 检查API服务器连接")
print(" 2. 获取产品列表")
print(" 3. 启动训练任务")
print(" 4. 监控训练进度")
print("📝 注意: 请在另一个终端运行 'uv run server/api.py' 启动API服务器")
print("="*60)
success = test_api_training()
print("\n" + "="*60)
if success:
print("🎉 API训练测试完成!")
else:
print("❌ API训练测试失败!")
print("💡 关键是观察API服务器控制台是否有完整的训练日志输出")
print("="*60)
if __name__ == "__main__":
main()