63 lines
2.1 KiB
Python
63 lines
2.1 KiB
Python
#!/usr/bin/env python3
|
||
"""
|
||
测试训练日志输出
|
||
"""
|
||
import requests
|
||
import json
|
||
import time
|
||
|
||
def test_training_logs():
|
||
"""测试训练API和日志输出"""
|
||
|
||
# 1. 启动训练任务
|
||
print("=== 启动训练任务 ===")
|
||
training_data = {
|
||
"product_id": "P001",
|
||
"model_type": "mlstm",
|
||
"epochs": 3, # 只训练3个epoch进行测试
|
||
"training_mode": "product"
|
||
}
|
||
|
||
try:
|
||
response = requests.post(
|
||
'http://localhost:5000/api/training/start',
|
||
json=training_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
if response.status_code == 200:
|
||
result = response.json()
|
||
task_id = result.get('task_id')
|
||
print(f"✅ 训练任务已启动,任务ID: {task_id}")
|
||
|
||
# 2. 监控训练状态
|
||
print("\n=== 监控训练状态 ===")
|
||
for i in range(30): # 最多监控30次
|
||
time.sleep(2)
|
||
|
||
status_response = requests.get(f'http://localhost:5000/api/training/status/{task_id}')
|
||
if status_response.status_code == 200:
|
||
status_data = status_response.json()
|
||
status = status_data.get('status', 'unknown')
|
||
print(f"第{i+1}次检查 - 状态: {status}")
|
||
|
||
if status == 'completed':
|
||
print("✅ 训练已完成!")
|
||
print(f"指标: {status_data.get('metrics', 'None')}")
|
||
break
|
||
elif status == 'failed':
|
||
print("❌ 训练失败!")
|
||
print(f"错误: {status_data.get('error', 'Unknown')}")
|
||
break
|
||
else:
|
||
print(f"⚠️ 无法获取状态: {status_response.status_code}")
|
||
|
||
else:
|
||
print(f"❌ 启动训练失败: {response.status_code}")
|
||
print(f"响应: {response.text}")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
|
||
if __name__ == "__main__":
|
||
test_training_logs() |