ShopTRAINING/test/direct_api_training_test.py

149 lines
5.2 KiB
Python
Raw Permalink Normal View History

2025-07-02 11:05:23 +08:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import urllib.request
import json
import time
def test_api_training():
"""直接测试API训练功能"""
print("=" * 60)
print("🔧 直接API训练测试")
print("=" * 60)
# 测试产品列表端点
try:
print("📋 测试产品列表API...")
with urllib.request.urlopen("http://127.0.0.1:5000/api/products", timeout=10) as response:
if response.getcode() == 200:
data = json.loads(response.read().decode('utf-8'))
products = data.get('data', [])
print(f"✅ 产品列表获取成功,找到 {len(products)} 个产品")
for p in products[:3]:
print(f" - {p['id']}: {p['name']}")
else:
print(f"❌ 产品列表失败: {response.getcode()}")
return False
except Exception as e:
print(f"❌ 产品列表异常: {e}")
return False
# 提交训练任务
try:
print("\n🚀 提交训练任务...")
training_data = {
"product_id": "P001",
"model_type": "transformer",
"epochs": 2,
"training_mode": "product"
}
json_data = json.dumps(training_data).encode('utf-8')
req = urllib.request.Request(
"http://127.0.0.1:5000/api/training",
data=json_data,
headers={'Content-Type': 'application/json'}
)
with urllib.request.urlopen(req, timeout=30) as response:
if response.getcode() == 200:
result = json.loads(response.read().decode('utf-8'))
task_id = result.get('task_id')
message = result.get('message', '')
print(f"✅ 训练任务提交成功:")
print(f" 任务ID: {task_id}")
print(f" 消息: {message}")
return task_id
else:
print(f"❌ 训练提交失败: {response.getcode()}")
return None
except Exception as e:
print(f"❌ 训练提交异常: {e}")
return None
def monitor_training(task_id, max_wait=60):
"""监控训练状态"""
if not task_id:
return False
print(f"\n⏳ 监控训练任务: {task_id}")
print("-" * 40)
start_time = time.time()
check_count = 0
while time.time() - start_time < max_wait:
check_count += 1
try:
url = f"http://127.0.0.1:5000/api/training/{task_id}"
with urllib.request.urlopen(url, timeout=10) as response:
if response.getcode() == 200:
result = json.loads(response.read().decode('utf-8'))
if result.get('status') == 'success':
data = result.get('data', {})
status = data.get('status', 'unknown')
progress = data.get('progress', 0)
message = data.get('message', '')
elapsed = time.time() - start_time
print(f"[{check_count:2d}] {elapsed:4.1f}s | {status:10s} | {progress:5.1f}% | {message}")
if status == 'completed':
metrics = data.get('metrics', {})
print(f"\n✅ 训练完成!")
if metrics:
print("📊 训练指标:")
for key, value in metrics.items():
if isinstance(value, (int, float)):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
return True
elif status == 'failed':
error = data.get('error', 'Unknown error')
print(f"❌ 训练失败: {error}")
return False
else:
print(f"❌ 状态查询失败: {response.getcode()}")
return False
except Exception as e:
print(f"❌ 状态查询异常: {e}")
return False
time.sleep(3)
print(f"⏰ 训练监控超时 ({max_wait}秒)")
return False
def main():
print("🧪 开始直接API训练测试")
# 提交训练任务
task_id = test_api_training()
if task_id:
# 监控训练进度
success = monitor_training(task_id, max_wait=120)
print("\n" + "=" * 60)
if success:
print("🎉 API训练测试成功!")
print("✅ 确认: API训练功能正常工作")
print("✅ 确认: 训练日志输出正常")
print("✅ 确认: 训练指标返回完整")
else:
print("❌ API训练测试失败")
print("⚠️ 需要进一步调试训练过程")
else:
print("\n❌ 无法提交训练任务API可能有问题")
print("=" * 60)
if __name__ == "__main__":
main()