149 lines
5.2 KiB
Python
149 lines
5.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
|
||
import urllib.request
|
||
import json
|
||
import time
|
||
|
||
def test_api_training():
|
||
"""直接测试API训练功能"""
|
||
|
||
print("=" * 60)
|
||
print("🔧 直接API训练测试")
|
||
print("=" * 60)
|
||
|
||
# 测试产品列表端点
|
||
try:
|
||
print("📋 测试产品列表API...")
|
||
with urllib.request.urlopen("http://127.0.0.1:5000/api/products", timeout=10) as response:
|
||
if response.getcode() == 200:
|
||
data = json.loads(response.read().decode('utf-8'))
|
||
products = data.get('data', [])
|
||
print(f"✅ 产品列表获取成功,找到 {len(products)} 个产品")
|
||
for p in products[:3]:
|
||
print(f" - {p['id']}: {p['name']}")
|
||
else:
|
||
print(f"❌ 产品列表失败: {response.getcode()}")
|
||
return False
|
||
except Exception as e:
|
||
print(f"❌ 产品列表异常: {e}")
|
||
return False
|
||
|
||
# 提交训练任务
|
||
try:
|
||
print("\n🚀 提交训练任务...")
|
||
training_data = {
|
||
"product_id": "P001",
|
||
"model_type": "transformer",
|
||
"epochs": 2,
|
||
"training_mode": "product"
|
||
}
|
||
|
||
json_data = json.dumps(training_data).encode('utf-8')
|
||
req = urllib.request.Request(
|
||
"http://127.0.0.1:5000/api/training",
|
||
data=json_data,
|
||
headers={'Content-Type': 'application/json'}
|
||
)
|
||
|
||
with urllib.request.urlopen(req, timeout=30) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
task_id = result.get('task_id')
|
||
message = result.get('message', '')
|
||
|
||
print(f"✅ 训练任务提交成功:")
|
||
print(f" 任务ID: {task_id}")
|
||
print(f" 消息: {message}")
|
||
|
||
return task_id
|
||
else:
|
||
print(f"❌ 训练提交失败: {response.getcode()}")
|
||
return None
|
||
|
||
except Exception as e:
|
||
print(f"❌ 训练提交异常: {e}")
|
||
return None
|
||
|
||
def monitor_training(task_id, max_wait=60):
|
||
"""监控训练状态"""
|
||
if not task_id:
|
||
return False
|
||
|
||
print(f"\n⏳ 监控训练任务: {task_id}")
|
||
print("-" * 40)
|
||
|
||
start_time = time.time()
|
||
check_count = 0
|
||
|
||
while time.time() - start_time < max_wait:
|
||
check_count += 1
|
||
try:
|
||
url = f"http://127.0.0.1:5000/api/training/{task_id}"
|
||
with urllib.request.urlopen(url, timeout=10) as response:
|
||
if response.getcode() == 200:
|
||
result = json.loads(response.read().decode('utf-8'))
|
||
|
||
if result.get('status') == 'success':
|
||
data = result.get('data', {})
|
||
status = data.get('status', 'unknown')
|
||
progress = data.get('progress', 0)
|
||
message = data.get('message', '')
|
||
|
||
elapsed = time.time() - start_time
|
||
print(f"[{check_count:2d}] {elapsed:4.1f}s | {status:10s} | {progress:5.1f}% | {message}")
|
||
|
||
if status == 'completed':
|
||
metrics = data.get('metrics', {})
|
||
print(f"\n✅ 训练完成!")
|
||
if metrics:
|
||
print("📊 训练指标:")
|
||
for key, value in metrics.items():
|
||
if isinstance(value, (int, float)):
|
||
print(f" {key}: {value:.4f}")
|
||
else:
|
||
print(f" {key}: {value}")
|
||
return True
|
||
elif status == 'failed':
|
||
error = data.get('error', 'Unknown error')
|
||
print(f"❌ 训练失败: {error}")
|
||
return False
|
||
else:
|
||
print(f"❌ 状态查询失败: {response.getcode()}")
|
||
return False
|
||
|
||
except Exception as e:
|
||
print(f"❌ 状态查询异常: {e}")
|
||
return False
|
||
|
||
time.sleep(3)
|
||
|
||
print(f"⏰ 训练监控超时 ({max_wait}秒)")
|
||
return False
|
||
|
||
def main():
|
||
print("🧪 开始直接API训练测试")
|
||
|
||
# 提交训练任务
|
||
task_id = test_api_training()
|
||
|
||
if task_id:
|
||
# 监控训练进度
|
||
success = monitor_training(task_id, max_wait=120)
|
||
|
||
print("\n" + "=" * 60)
|
||
if success:
|
||
print("🎉 API训练测试成功!")
|
||
print("✅ 确认: API训练功能正常工作")
|
||
print("✅ 确认: 训练日志输出正常")
|
||
print("✅ 确认: 训练指标返回完整")
|
||
else:
|
||
print("❌ API训练测试失败")
|
||
print("⚠️ 需要进一步调试训练过程")
|
||
else:
|
||
print("\n❌ 无法提交训练任务,API可能有问题")
|
||
|
||
print("=" * 60)
|
||
|
||
if __name__ == "__main__":
|
||
main() |