ShopTRAINING/test/test_api_training_fix.py
2025-07-02 11:05:23 +08:00

120 lines
4.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试API训练日志输出修复效果
"""
import os
import sys
import time
import requests
import json
# 设置环境变量
os.environ['PYTHONIOENCODING'] = 'utf-8'
def test_api_training():
"""测试API训练功能"""
print("=" * 60)
print("🧪 API训练日志输出修复测试")
print("=" * 60)
# API服务器地址
api_base = "http://localhost:5000"
# 测试连接
try:
response = requests.get(f"{api_base}/api/health", timeout=5)
if response.status_code == 200:
print("✅ API服务器连接正常")
else:
print(f"⚠️ API服务器响应异常: {response.status_code}")
return
except Exception as e:
print(f"❌ 无法连接到API服务器: {e}")
print("💡 请先启动API服务器: PYTHONIOENCODING=utf-8 uv run server/api.py")
return
# 准备训练请求
training_data = {
"product_id": "P001",
"model_type": "transformer",
"epochs": 3, # 使用较少轮次快速测试
"training_mode": "product"
}
print(f"\n📋 发送训练请求:")
print(f" 产品ID: {training_data['product_id']}")
print(f" 模型类型: {training_data['model_type']}")
print(f" 训练轮次: {training_data['epochs']}")
print(f" 训练模式: {training_data['training_mode']}")
# 发送训练请求
try:
response = requests.post(
f"{api_base}/api/start_training",
json=training_data,
timeout=10
)
if response.status_code == 200:
result = response.json()
task_id = result.get('task_id')
print(f"\n✅ 训练任务已提交")
print(f"📋 任务ID: {task_id}")
print(f"📄 响应: {result}")
# 监控任务状态
print(f"\n🔄 监控训练状态...")
for i in range(30): # 最多监控30秒
try:
status_response = requests.get(
f"{api_base}/api/training_status/{task_id}",
timeout=5
)
if status_response.status_code == 200:
status_data = status_response.json()
task_status = status_data.get('status', 'unknown')
progress = status_data.get('progress', 0)
message = status_data.get('message', '')
print(f"[{i+1:2d}] 状态: {task_status} | 进度: {progress:6.1f}% | {message}")
if task_status in ['completed', 'failed']:
if task_status == 'completed':
metrics = status_data.get('metrics')
print(f"\n🎉 训练完成!")
if metrics:
print(f"📊 训练指标: {json.dumps(metrics, indent=2, ensure_ascii=False)}")
else:
print(f"⚠️ 未返回训练指标")
else:
error = status_data.get('error', '未知错误')
print(f"\n❌ 训练失败: {error}")
break
else:
print(f"[{i+1:2d}] 获取状态失败: {status_response.status_code}")
except Exception as e:
print(f"[{i+1:2d}] 状态查询错误: {e}")
time.sleep(1)
else:
print(f"\n⏰ 监控超时,任务可能仍在运行")
else:
print(f"❌ 训练请求失败: {response.status_code}")
print(f"📄 响应: {response.text}")
except Exception as e:
print(f"❌ 发送训练请求失败: {e}")
print("\n" + "=" * 60)
print("🎯 测试完成")
print("💡 如果在API服务器控制台看到训练日志输出说明修复成功!")
print("=" * 60)
if __name__ == "__main__":
test_api_training()