112 lines
3.6 KiB
Python
112 lines
3.6 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试修复后的API训练日志
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
import requests
|
|||
|
import json
|
|||
|
|
|||
|
# 设置编码环境
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
|
|||
|
def test_api_training_logs():
|
|||
|
"""测试API训练日志是否正常显示"""
|
|||
|
print("🧪 测试API训练日志修复效果")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
# 检查API服务器是否在运行
|
|||
|
try:
|
|||
|
response = requests.get('http://localhost:5000/api/products', timeout=5)
|
|||
|
if response.status_code != 200:
|
|||
|
print("❌ API服务器不可访问")
|
|||
|
return False
|
|||
|
print("✅ API服务器正在运行")
|
|||
|
except requests.exceptions.ConnectionError:
|
|||
|
print("❌ API服务器未启动,请先启动: PYTHONIOENCODING=utf-8 uv run server/api.py")
|
|||
|
return False
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ API连接错误: {e}")
|
|||
|
return False
|
|||
|
|
|||
|
# 启动训练任务
|
|||
|
print("\n🚀 启动训练任务...")
|
|||
|
training_data = {
|
|||
|
"product_id": "P003", # 使用不同的产品避免冲突
|
|||
|
"model_type": "transformer",
|
|||
|
"epochs": 5 # 较短的训练用于测试
|
|||
|
}
|
|||
|
|
|||
|
try:
|
|||
|
response = requests.post(
|
|||
|
'http://localhost:5000/api/training',
|
|||
|
json=training_data,
|
|||
|
timeout=10
|
|||
|
)
|
|||
|
|
|||
|
if response.status_code != 200:
|
|||
|
print(f"❌ 训练启动失败: {response.status_code}")
|
|||
|
print(f"响应: {response.text}")
|
|||
|
return False
|
|||
|
|
|||
|
result = response.json()
|
|||
|
task_id = result.get('task_id')
|
|||
|
|
|||
|
if not task_id:
|
|||
|
print("❌ 未获得任务ID")
|
|||
|
return False
|
|||
|
|
|||
|
print(f"✅ 训练任务已启动: {task_id[:8]}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 启动训练失败: {e}")
|
|||
|
return False
|
|||
|
|
|||
|
# 监控训练状态
|
|||
|
print(f"\n📊 监控训练状态 (任务: {task_id[:8]})...")
|
|||
|
print("检查API服务器控制台是否有训练日志输出...")
|
|||
|
|
|||
|
for i in range(30): # 最多等待30秒
|
|||
|
try:
|
|||
|
response = requests.get(f'http://localhost:5000/api/training/{task_id}', timeout=5)
|
|||
|
if response.status_code == 200:
|
|||
|
status_data = response.json()
|
|||
|
task_status = status_data.get('data', {}).get('status', 'unknown')
|
|||
|
|
|||
|
print(f"⏱️ [{i+1:2d}/30] 状态: {task_status}")
|
|||
|
|
|||
|
if task_status == 'completed':
|
|||
|
print("✅ 训练完成!")
|
|||
|
metrics = status_data.get('data', {}).get('metrics')
|
|||
|
if metrics:
|
|||
|
print(f"📊 训练指标: {metrics}")
|
|||
|
break
|
|||
|
elif task_status == 'failed':
|
|||
|
print("❌ 训练失败!")
|
|||
|
error = status_data.get('data', {}).get('error')
|
|||
|
if error:
|
|||
|
print(f"错误信息: {error}")
|
|||
|
break
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"⚠️ 状态查询错误: {e}")
|
|||
|
|
|||
|
time.sleep(2)
|
|||
|
|
|||
|
print("\n" + "=" * 50)
|
|||
|
print("🎯 测试结果:")
|
|||
|
print("请检查API服务器控制台是否显示了以下类型的日志:")
|
|||
|
print("- [task_id] 🚀 训练进程启动")
|
|||
|
print("- [task_id] 📋 任务参数")
|
|||
|
print("- [task_id] 🤖 开始执行模型训练")
|
|||
|
print("- [task_id] ✅ 模型训练完成")
|
|||
|
print("- [task_id] 📊 训练指标")
|
|||
|
print("\n如果看到这些日志,说明修复成功!")
|
|||
|
|
|||
|
return True
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_api_training_logs()
|