120 lines
4.2 KiB
Python
120 lines
4.2 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试API训练日志输出修复效果
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
import requests
|
|||
|
import json
|
|||
|
|
|||
|
# 设置环境变量
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
|
|||
|
def test_api_training():
|
|||
|
"""测试API训练功能"""
|
|||
|
|
|||
|
print("=" * 60)
|
|||
|
print("🧪 API训练日志输出修复测试")
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
# API服务器地址
|
|||
|
api_base = "http://localhost:5000"
|
|||
|
|
|||
|
# 测试连接
|
|||
|
try:
|
|||
|
response = requests.get(f"{api_base}/api/health", timeout=5)
|
|||
|
if response.status_code == 200:
|
|||
|
print("✅ API服务器连接正常")
|
|||
|
else:
|
|||
|
print(f"⚠️ API服务器响应异常: {response.status_code}")
|
|||
|
return
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 无法连接到API服务器: {e}")
|
|||
|
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
|||
|
return
|
|||
|
|
|||
|
# 准备训练请求
|
|||
|
training_data = {
|
|||
|
"product_id": "P001",
|
|||
|
"model_type": "transformer",
|
|||
|
"epochs": 3, # 使用较少轮次快速测试
|
|||
|
"training_mode": "product"
|
|||
|
}
|
|||
|
|
|||
|
print(f"\n📋 发送训练请求:")
|
|||
|
print(f" 产品ID: {training_data['product_id']}")
|
|||
|
print(f" 模型类型: {training_data['model_type']}")
|
|||
|
print(f" 训练轮次: {training_data['epochs']}")
|
|||
|
print(f" 训练模式: {training_data['training_mode']}")
|
|||
|
|
|||
|
# 发送训练请求
|
|||
|
try:
|
|||
|
response = requests.post(
|
|||
|
f"{api_base}/api/start_training",
|
|||
|
json=training_data,
|
|||
|
timeout=10
|
|||
|
)
|
|||
|
|
|||
|
if response.status_code == 200:
|
|||
|
result = response.json()
|
|||
|
task_id = result.get('task_id')
|
|||
|
print(f"\n✅ 训练任务已提交")
|
|||
|
print(f"📋 任务ID: {task_id}")
|
|||
|
print(f"📄 响应: {result}")
|
|||
|
|
|||
|
# 监控任务状态
|
|||
|
print(f"\n🔄 监控训练状态...")
|
|||
|
for i in range(30): # 最多监控30秒
|
|||
|
try:
|
|||
|
status_response = requests.get(
|
|||
|
f"{api_base}/api/training_status/{task_id}",
|
|||
|
timeout=5
|
|||
|
)
|
|||
|
|
|||
|
if status_response.status_code == 200:
|
|||
|
status_data = status_response.json()
|
|||
|
task_status = status_data.get('status', 'unknown')
|
|||
|
progress = status_data.get('progress', 0)
|
|||
|
message = status_data.get('message', '')
|
|||
|
|
|||
|
print(f"[{i+1:2d}] 状态: {task_status} | 进度: {progress:6.1f}% | {message}")
|
|||
|
|
|||
|
if task_status in ['completed', 'failed']:
|
|||
|
if task_status == 'completed':
|
|||
|
metrics = status_data.get('metrics')
|
|||
|
print(f"\n🎉 训练完成!")
|
|||
|
if metrics:
|
|||
|
print(f"📊 训练指标: {json.dumps(metrics, indent=2, ensure_ascii=False)}")
|
|||
|
else:
|
|||
|
print(f"⚠️ 未返回训练指标")
|
|||
|
else:
|
|||
|
error = status_data.get('error', '未知错误')
|
|||
|
print(f"\n❌ 训练失败: {error}")
|
|||
|
break
|
|||
|
else:
|
|||
|
print(f"[{i+1:2d}] 获取状态失败: {status_response.status_code}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"[{i+1:2d}] 状态查询错误: {e}")
|
|||
|
|
|||
|
time.sleep(1)
|
|||
|
else:
|
|||
|
print(f"\n⏰ 监控超时,任务可能仍在运行")
|
|||
|
|
|||
|
else:
|
|||
|
print(f"❌ 训练请求失败: {response.status_code}")
|
|||
|
print(f"📄 响应: {response.text}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 发送训练请求失败: {e}")
|
|||
|
|
|||
|
print("\n" + "=" * 60)
|
|||
|
print("🎯 测试完成")
|
|||
|
print("💡 如果在API服务器控制台看到训练日志输出,说明修复成功!")
|
|||
|
print("=" * 60)
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_api_training()
|