ShopTRAINING/test/direct_api_training_test.py
2025-07-02 11:05:23 +08:00

149 lines
5.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import urllib.request
import json
import time
def test_api_training(base_url="http://127.0.0.1:5000"):
    """Smoke-test the training API and submit one training task.

    First requests ``/api/products`` as a reachability check and prints up to
    three of the returned products. Then POSTs a fixed training request
    (product ``P001``, ``transformer`` model, 2 epochs, ``product`` mode) to
    ``/api/training``.

    Args:
        base_url: Scheme, host and port of the API server. Defaults to the
            local address this script has always targeted.

    Returns:
        The ``task_id`` reported by the server when the submission succeeds.
        ``False`` when the product-list check fails, and ``None`` when the
        submission fails or the response carries no ``task_id``. Both failure
        values are falsy, so callers can simply test the result's truthiness.
    """
    api_root = base_url.rstrip("/")

    print("=" * 60)
    print("🔧 直接API训练测试")
    print("=" * 60)

    # Step 1: product-list endpoint.
    try:
        print("📋 测试产品列表API...")
        with urllib.request.urlopen(f"{api_root}/api/products", timeout=10) as response:
            # urlopen raises HTTPError for 4xx/5xx responses, so this branch
            # only catches unexpected non-200 success codes.
            if response.getcode() != 200:
                print(f"❌ 产品列表失败: {response.getcode()}")
                return False
            data = json.loads(response.read().decode('utf-8'))

        # `or []` also covers an explicit JSON null, which would break len().
        products = data.get('data') or []
        print(f"✅ 产品列表获取成功,找到 {len(products)} 个产品")
        for p in products[:3]:
            # This listing is informational only: a product lacking a field
            # must not abort the whole run, hence .get() instead of indexing.
            print(f" - {p.get('id')}: {p.get('name')}")
    except Exception as e:
        print(f"❌ 产品列表异常: {e}")
        return False

    # Step 2: submit the training task.
    try:
        print("\n🚀 提交训练任务...")
        training_data = {
            "product_id": "P001",
            "model_type": "transformer",
            "epochs": 2,
            "training_mode": "product",
        }
        req = urllib.request.Request(
            f"{api_root}/api/training",
            data=json.dumps(training_data).encode('utf-8'),
            headers={'Content-Type': 'application/json'},
        )
        with urllib.request.urlopen(req, timeout=30) as response:
            if response.getcode() != 200:
                print(f"❌ 训练提交失败: {response.getcode()}")
                return None
            result = json.loads(response.read().decode('utf-8'))

        task_id = result.get('task_id')
        if not task_id:
            # Without a task id there is nothing to monitor, so do not
            # report this as a successful submission.
            print(f"❌ 训练提交失败: 响应中缺少 task_id ({result})")
            return None

        message = result.get('message', '')
        print("✅ 训练任务提交成功:")
        print(f" 任务ID: {task_id}")
        print(f" 消息: {message}")
        return task_id
    except Exception as e:
        print(f"❌ 训练提交异常: {e}")
        return None
def monitor_training(task_id, max_wait=60, poll_interval=3,
                     base_url="http://127.0.0.1:5000") -> bool:
    """Poll the training-status endpoint until the task ends or time runs out.

    Each poll requests ``/api/training/<task_id>`` and prints one progress
    line. Polling stops as soon as the task reports ``completed`` or
    ``failed``, when a request fails, or when ``max_wait`` seconds have passed.

    Args:
        task_id: Identifier returned when the task was submitted. A falsy
            value short-circuits to ``False`` without any request.
        max_wait: Maximum number of seconds to keep polling.
        poll_interval: Seconds to sleep between polls.
        base_url: Scheme, host and port of the API server.

    Returns:
        ``True`` only when the task reports ``completed``. ``False`` for a
        failed task, a failed or non-200 status query, a timeout, or a falsy
        ``task_id``.
    """
    if not task_id:
        return False

    print(f"\n⏳ 监控训练任务: {task_id}")
    print("-" * 40)

    status_url = f"{base_url.rstrip('/')}/api/training/{task_id}"
    # monotonic() is immune to wall-clock adjustments during the wait.
    started = time.monotonic()
    check_count = 0

    while time.monotonic() - started < max_wait:
        check_count += 1
        try:
            with urllib.request.urlopen(status_url, timeout=10) as response:
                http_code = response.getcode()
                if http_code != 200:
                    print(f"❌ 状态查询失败: {http_code}")
                    return False
                result = json.loads(response.read().decode('utf-8'))

            if result.get('status') == 'success':
                # `or {}` also covers an explicit JSON null payload.
                data = result.get('data') or {}
                # Normalise before formatting: the ":10s" / ":5.1f" specs
                # raise on None or non-numeric values, which used to end the
                # monitoring with a misleading "query exception".
                status = str(data.get('status') or 'unknown')
                try:
                    progress = float(data.get('progress') or 0)
                except (TypeError, ValueError):
                    progress = 0.0
                message = data.get('message', '')
                elapsed = time.monotonic() - started
                print(f"[{check_count:2d}] {elapsed:4.1f}s | {status:10s} | {progress:5.1f}% | {message}")

                if status == 'completed':
                    metrics = data.get('metrics') or {}
                    print("\n✅ 训练完成!")
                    if metrics:
                        print("📊 训练指标:")
                        for key, value in metrics.items():
                            # bool is an int subclass; keep True/False
                            # readable instead of printing 1.0000/0.0000.
                            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
                            if is_number:
                                print(f" {key}: {value:.4f}")
                            else:
                                print(f" {key}: {value}")
                    return True

                if status == 'failed':
                    error = data.get('error', 'Unknown error')
                    print(f"❌ 训练失败: {error}")
                    return False
            else:
                # Keep polling as before, but make the unexpected envelope
                # visible instead of waiting silently until the timeout.
                print(f"[{check_count:2d}] ⚠️ 状态响应非 success: {result}")
        except Exception as e:
            print(f"❌ 状态查询异常: {e}")
            return False

        # Never sleep past the deadline.
        remaining = max_wait - (time.monotonic() - started)
        if remaining <= 0:
            break
        time.sleep(max(0.0, min(poll_interval, remaining)))

    print(f"⏰ 训练监控超时 ({max_wait}秒)")
    return False
def main():
    """Run the end-to-end check: submit a training task, then monitor it.

    Returns:
        ``True`` when the task was submitted and monitoring reported success,
        ``False`` otherwise. The script entry point below converts this into
        the process exit status.
    """
    print("🧪 开始直接API训练测试")

    # Submit the training task.
    task_id = test_api_training()
    if not task_id:
        print("\n❌ 无法提交训练任务API可能有问题")
        print("=" * 60)
        return False

    # Follow the task until it finishes or the monitoring window elapses.
    success = monitor_training(task_id, max_wait=120)

    print("\n" + "=" * 60)
    if success:
        print("🎉 API训练测试成功!")
        print("✅ 确认: API训练功能正常工作")
        print("✅ 确认: 训练日志输出正常")
        print("✅ 确认: 训练指标返回完整")
    else:
        print("❌ API训练测试失败")
        print("⚠️ 需要进一步调试训练过程")
    print("=" * 60)
    return success


if __name__ == "__main__":
    import sys

    # Exit non-zero on failure so shells and CI jobs can detect it; the
    # script previously always exited with status 0.
    sys.exit(0 if main() else 1)