63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
"""
|
|||
|
测试训练日志输出
|
|||
|
"""
|
|||
|
import requests
|
|||
|
import json
|
|||
|
import time
|
|||
|
|
|||
|
def test_training_logs():
|
|||
|
"""测试训练API和日志输出"""
|
|||
|
|
|||
|
# 1. 启动训练任务
|
|||
|
print("=== 启动训练任务 ===")
|
|||
|
training_data = {
|
|||
|
"product_id": "P001",
|
|||
|
"model_type": "mlstm",
|
|||
|
"epochs": 3, # 只训练3个epoch进行测试
|
|||
|
"training_mode": "product"
|
|||
|
}
|
|||
|
|
|||
|
try:
|
|||
|
response = requests.post(
|
|||
|
'http://localhost:5000/api/training/start',
|
|||
|
json=training_data,
|
|||
|
headers={'Content-Type': 'application/json'}
|
|||
|
)
|
|||
|
|
|||
|
if response.status_code == 200:
|
|||
|
result = response.json()
|
|||
|
task_id = result.get('task_id')
|
|||
|
print(f"✅ 训练任务已启动,任务ID: {task_id}")
|
|||
|
|
|||
|
# 2. 监控训练状态
|
|||
|
print("\n=== 监控训练状态 ===")
|
|||
|
for i in range(30): # 最多监控30次
|
|||
|
time.sleep(2)
|
|||
|
|
|||
|
status_response = requests.get(f'http://localhost:5000/api/training/status/{task_id}')
|
|||
|
if status_response.status_code == 200:
|
|||
|
status_data = status_response.json()
|
|||
|
status = status_data.get('status', 'unknown')
|
|||
|
print(f"第{i+1}次检查 - 状态: {status}")
|
|||
|
|
|||
|
if status == 'completed':
|
|||
|
print("✅ 训练已完成!")
|
|||
|
print(f"指标: {status_data.get('metrics', 'None')}")
|
|||
|
break
|
|||
|
elif status == 'failed':
|
|||
|
print("❌ 训练失败!")
|
|||
|
print(f"错误: {status_data.get('error', 'Unknown')}")
|
|||
|
break
|
|||
|
else:
|
|||
|
print(f"⚠️ 无法获取状态: {status_response.status_code}")
|
|||
|
|
|||
|
else:
|
|||
|
print(f"❌ 启动训练失败: {response.status_code}")
|
|||
|
print(f"响应: {response.text}")
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 测试失败: {e}")
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
test_training_logs()
|