173 lines
6.4 KiB
Python
173 lines
6.4 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
HTTP训练测试 - 调用已运行的API服务器进行训练测试
|
|
要求: 请先在另一个终端运行 uv run server/api.py 启动API服务器
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# 在导入其他模块前设置编码
|
|
if os.name == 'nt':
|
|
# Windows系统编码设置
|
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|
os.system('chcp 65001 >nul 2>&1')
|
|
|
|
import requests
|
|
import json
|
|
import time
|
|
|
|
def test_api_training():
|
|
"""测试API训练功能"""
|
|
|
|
print("🧪 API训练测试开始")
|
|
print("="*50)
|
|
|
|
api_url = "http://127.0.0.1:5000"
|
|
|
|
try:
|
|
# 1. 检查API服务器状态
|
|
print("🔍 检查API服务器状态...")
|
|
try:
|
|
response = requests.get(f"{api_url}/api/products", timeout=5)
|
|
if response.status_code != 200:
|
|
print("❌ API服务器未运行")
|
|
print("💡 请先运行: uv run server/api.py")
|
|
return False
|
|
print("✅ API服务器运行正常")
|
|
except requests.exceptions.ConnectionError:
|
|
print("❌ 无法连接到API服务器")
|
|
print("💡 请先运行: uv run server/api.py")
|
|
return False
|
|
|
|
# 2. 获取产品列表
|
|
print("\n📋 获取产品列表...")
|
|
products = response.json().get('data', [])
|
|
if not products:
|
|
print("❌ 没有找到产品数据")
|
|
return False
|
|
|
|
product_id = products[0]['product_id']
|
|
product_name = products[0]['product_name']
|
|
print(f"✅ 选择产品: {product_id} - {product_name}")
|
|
|
|
# 3. 启动训练任务
|
|
print(f"\n🚀 启动训练任务...")
|
|
training_data = {
|
|
"product_id": product_id,
|
|
"model_type": "transformer",
|
|
"epochs": 3, # 使用较少轮次快速测试
|
|
"training_mode": "product"
|
|
}
|
|
|
|
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
|
|
print("🔍 请观察API服务器控制台输出...")
|
|
print("-" * 60)
|
|
|
|
response = requests.post(
|
|
f"{api_url}/api/training",
|
|
json=training_data,
|
|
timeout=180 # 3分钟超时
|
|
)
|
|
|
|
print("-" * 60)
|
|
print(f"📡 API响应状态: {response.status_code}")
|
|
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
task_id = result.get('task_id')
|
|
print(f"✅ 训练任务已启动")
|
|
print(f"🔖 任务ID: {task_id}")
|
|
print(f"💬 消息: {result.get('message', 'N/A')}")
|
|
|
|
# 4. 监控训练状态
|
|
print(f"\n⏳ 监控训练状态...")
|
|
max_checks = 30 # 最多检查30次
|
|
for i in range(max_checks):
|
|
time.sleep(2) # 每2秒检查一次
|
|
|
|
try:
|
|
status_response = requests.get(f"{api_url}/api/training", timeout=5)
|
|
if status_response.status_code == 200:
|
|
tasks = status_response.json().get('data', [])
|
|
current_task = None
|
|
|
|
for task in tasks:
|
|
if task.get('task_id') == task_id:
|
|
current_task = task
|
|
break
|
|
|
|
if current_task:
|
|
status = current_task.get('status', 'unknown')
|
|
print(f"📊 训练状态: {status} (检查 {i+1}/{max_checks})")
|
|
|
|
if status == 'completed':
|
|
print(f"\n✅ 训练成功完成!")
|
|
metrics = current_task.get('metrics', {})
|
|
if metrics:
|
|
print(f"📈 训练指标:")
|
|
for key, value in metrics.items():
|
|
if isinstance(value, (int, float)):
|
|
print(f" {key}: {value:.4f}")
|
|
else:
|
|
print(f" {key}: {value}")
|
|
else:
|
|
print("⚠️ 训练指标为空")
|
|
return True
|
|
|
|
elif status == 'failed':
|
|
print(f"\n❌ 训练失败")
|
|
error = current_task.get('error', 'Unknown error')
|
|
print(f"🔴 错误信息: {error}")
|
|
return False
|
|
|
|
else:
|
|
print(f"⚠️ 找不到任务 {task_id}")
|
|
|
|
except Exception as e:
|
|
print(f"⚠️ 状态检查异常: {e}")
|
|
|
|
print(f"\n⏰ 训练监控超时,但任务可能仍在运行")
|
|
print(f"💡 请查看API服务器控制台确认训练状态")
|
|
|
|
else:
|
|
print(f"❌ 训练启动失败: {response.status_code}")
|
|
try:
|
|
error_info = response.json()
|
|
print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False)}")
|
|
except:
|
|
print(f"🔴 错误内容: {response.text}")
|
|
return False
|
|
|
|
except Exception as e:
|
|
print(f"❌ 测试异常: {e}")
|
|
return False
|
|
|
|
return True
|
|
|
|
def main():
|
|
"""主函数"""
|
|
print("\n" + "="*60)
|
|
print("🔧 药店销售预测系统 - API训练测试")
|
|
print("="*60)
|
|
print("📋 测试步骤:")
|
|
print(" 1. 检查API服务器连接")
|
|
print(" 2. 获取产品列表")
|
|
print(" 3. 启动训练任务")
|
|
print(" 4. 监控训练进度")
|
|
print("📝 注意: 请在另一个终端运行 'uv run server/api.py' 启动API服务器")
|
|
print("="*60)
|
|
|
|
success = test_api_training()
|
|
|
|
print("\n" + "="*60)
|
|
if success:
|
|
print("🎉 API训练测试完成!")
|
|
else:
|
|
print("❌ API训练测试失败!")
|
|
print("💡 关键是观察API服务器控制台是否有完整的训练日志输出")
|
|
print("="*60)
|
|
|
|
if __name__ == "__main__":
|
|
main() |