ShopTRAINING/test/test_training_logs.py
2025-07-02 11:05:23 +08:00

63 lines
2.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
"""
测试训练日志输出
"""
import requests
import json
import time
def test_training_logs(
    *,
    base_url="http://localhost:5000",
    max_checks=30,
    poll_interval=2,
    request_timeout=10,
):
    """Start a short training task through the HTTP API and poll its status.

    This is a manual integration check against a running server: it reports
    progress and outcome by printing, and does not assert. All parameters
    are keyword-only with defaults, so calling ``test_training_logs()``
    behaves like the original hard-coded version (and pytest does not treat
    them as fixtures).

    Keyword Args:
        base_url: Root URL of the training server.
        max_checks: Maximum number of status polls before giving up.
        poll_interval: Seconds to sleep before each status poll.
        request_timeout: Per-request timeout in seconds, so an unresponsive
            server cannot hang the check indefinitely.

    Network failures and non-JSON responses are caught and reported;
    any other exception (e.g. a bug in this script) propagates.
    """
    # 1. Start the training task.
    print("=== 启动训练任务 ===")
    training_data = {
        "product_id": "P001",
        "model_type": "mlstm",
        "epochs": 3,  # only 3 epochs, to keep the check short
        "training_mode": "product",
    }
    try:
        # `json=` serializes the body and sets Content-Type: application/json.
        response = requests.post(
            f"{base_url}/api/training/start",
            json=training_data,
            timeout=request_timeout,
        )
        if response.status_code != 200:
            print(f"❌ 启动训练失败: {response.status_code}")
            print(f"响应: {response.text}")
            return

        task_id = response.json().get('task_id')
        if not task_id:
            # Without an ID, polling would request .../status/None.
            print(f"❌ 响应中缺少task_id: {response.text}")
            return
        print(f"✅ 训练任务已启动任务ID: {task_id}")

        # 2. Poll the task status until it reaches a terminal state.
        print("\n=== 监控训练状态 ===")
        status_url = f"{base_url}/api/training/status/{task_id}"
        for attempt in range(1, max_checks + 1):
            time.sleep(poll_interval)
            status_response = requests.get(status_url, timeout=request_timeout)
            if status_response.status_code != 200:
                print(f"⚠️ 无法获取状态: {status_response.status_code}")
                continue

            status_data = status_response.json()
            status = status_data.get('status', 'unknown')
            print(f"{attempt}次检查 - 状态: {status}")
            if status == 'completed':
                print("✅ 训练已完成!")
                print(f"指标: {status_data.get('metrics', 'None')}")
                break
            if status == 'failed':
                print("❌ 训练失败!")
                print(f"错误: {status_data.get('error', 'Unknown')}")
                break
        else:
            # Every poll ran without a terminal status: report it explicitly
            # instead of ending silently.
            print(f"⚠️ 在{max_checks}次检查后训练仍未结束")
    except (requests.RequestException, ValueError) as e:
        # RequestException covers connection errors and timeouts; ValueError
        # covers a response body that is not valid JSON.
        print(f"❌ 测试失败: {e}")
if __name__ == "__main__":
    # Executed directly as a script (not imported or collected by pytest):
    # run the training-log check once.
    test_training_logs()