#!/usr/bin/env python3 # -*- coding: utf-8 -*- import urllib.request import json import time def test_api_training(): """直接测试API训练功能""" print("=" * 60) print("🔧 直接API训练测试") print("=" * 60) # 测试产品列表端点 try: print("📋 测试产品列表API...") with urllib.request.urlopen("http://127.0.0.1:5000/api/products", timeout=10) as response: if response.getcode() == 200: data = json.loads(response.read().decode('utf-8')) products = data.get('data', []) print(f"✅ 产品列表获取成功,找到 {len(products)} 个产品") for p in products[:3]: print(f" - {p['id']}: {p['name']}") else: print(f"❌ 产品列表失败: {response.getcode()}") return False except Exception as e: print(f"❌ 产品列表异常: {e}") return False # 提交训练任务 try: print("\n🚀 提交训练任务...") training_data = { "product_id": "P001", "model_type": "transformer", "epochs": 2, "training_mode": "product" } json_data = json.dumps(training_data).encode('utf-8') req = urllib.request.Request( "http://127.0.0.1:5000/api/training", data=json_data, headers={'Content-Type': 'application/json'} ) with urllib.request.urlopen(req, timeout=30) as response: if response.getcode() == 200: result = json.loads(response.read().decode('utf-8')) task_id = result.get('task_id') message = result.get('message', '') print(f"✅ 训练任务提交成功:") print(f" 任务ID: {task_id}") print(f" 消息: {message}") return task_id else: print(f"❌ 训练提交失败: {response.getcode()}") return None except Exception as e: print(f"❌ 训练提交异常: {e}") return None def monitor_training(task_id, max_wait=60): """监控训练状态""" if not task_id: return False print(f"\n⏳ 监控训练任务: {task_id}") print("-" * 40) start_time = time.time() check_count = 0 while time.time() - start_time < max_wait: check_count += 1 try: url = f"http://127.0.0.1:5000/api/training/{task_id}" with urllib.request.urlopen(url, timeout=10) as response: if response.getcode() == 200: result = json.loads(response.read().decode('utf-8')) if result.get('status') == 'success': data = result.get('data', {}) status = data.get('status', 'unknown') progress = data.get('progress', 0) message = data.get('message', '') elapsed = time.time() - start_time print(f"[{check_count:2d}] {elapsed:4.1f}s | {status:10s} | {progress:5.1f}% | {message}") if status == 'completed': metrics = data.get('metrics', {}) print(f"\n✅ 训练完成!") if metrics: print("📊 训练指标:") for key, value in metrics.items(): if isinstance(value, (int, float)): print(f" {key}: {value:.4f}") else: print(f" {key}: {value}") return True elif status == 'failed': error = data.get('error', 'Unknown error') print(f"❌ 训练失败: {error}") return False else: print(f"❌ 状态查询失败: {response.getcode()}") return False except Exception as e: print(f"❌ 状态查询异常: {e}") return False time.sleep(3) print(f"⏰ 训练监控超时 ({max_wait}秒)") return False def main(): print("🧪 开始直接API训练测试") # 提交训练任务 task_id = test_api_training() if task_id: # 监控训练进度 success = monitor_training(task_id, max_wait=120) print("\n" + "=" * 60) if success: print("🎉 API训练测试成功!") print("✅ 确认: API训练功能正常工作") print("✅ 确认: 训练日志输出正常") print("✅ 确认: 训练指标返回完整") else: print("❌ API训练测试失败") print("⚠️ 需要进一步调试训练过程") else: print("\n❌ 无法提交训练任务,API可能有问题") print("=" * 60) if __name__ == "__main__": main()