#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 简单的API训练测试 - 使用标准库测试API """ import os import sys import urllib.request import urllib.parse import json import time # 设置完整的UTF-8编码环境 os.environ['PYTHONIOENCODING'] = 'utf-8' os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0' # Windows控制台编码设置 if os.name == 'nt': try: import subprocess subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False) # 重新配置标准输出 if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace') except Exception as e: print(f"警告: UTF-8编码设置失败: {e}") def send_post_request(url, data): """发送POST请求""" try: json_data = json.dumps(data).encode('utf-8') req = urllib.request.Request( url, data=json_data, headers={'Content-Type': 'application/json'} ) with urllib.request.urlopen(req, timeout=30) as response: return response.getcode(), json.loads(response.read().decode('utf-8')) except Exception as e: return None, str(e) def send_get_request(url): """发送GET请求""" try: with urllib.request.urlopen(url, timeout=10) as response: return response.getcode(), json.loads(response.read().decode('utf-8')) except Exception as e: return None, str(e) def test_api_training(): """测试API训练功能""" print("🧪 简单API训练测试") print("=" * 50) base_url = "http://127.0.0.1:5000" # 0. 测试线程输出 print("0️⃣ 测试API线程输出机制...") status, response = send_post_request(f"{base_url}/api/test-thread-output", {}) if status == 200: print("✅ 线程输出测试成功") else: print(f"❌ 线程输出测试失败: {response}") print("\n1️⃣ 测试简化训练...") status, response = send_post_request(f"{base_url}/api/test-training-simple", {}) if status == 200: print("✅ 简化训练测试成功") else: print(f"❌ 简化训练测试失败: {response}") print("\n" + "="*50) # 1. 测试API连接 print("🔍 测试API连接...") status, response = send_get_request(f"{base_url}/api/products") if status != 200: print(f"❌ API连接失败: {response}") print("💡 请确保API服务器已启动") return print("✅ API连接成功") # 2. 获取产品信息 products = response.get('data', []) if not products: print("❌ 没有产品数据") return product_id = products[0]['product_id'] product_name = products[0]['product_name'] print(f"🎯 选择产品: {product_id} - {product_name}") # 3. 发送训练请求 print(f"\n🚀 发送训练请求...") training_data = { "product_id": product_id, "model_type": "transformer", "epochs": 2, # 快速测试 "training_mode": "product" } print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}") print("\n" + "="*60) print("🔍 请密切观察API服务器控制台输出!") print("🔍 应该看到以下调试信息:") print(" - 🚀🚀🚀 THREAD START: 准备启动训练线程") print(" - 🔥🔥🔥 TRAIN_TASK ENTRY: 函数已被调用!") print(" - 🚀 训练任务开始") print(" - 训练器调用和进度信息") print("="*60) # 发送训练请求 status, response = send_post_request(f"{base_url}/api/training", training_data) if status == 200: task_id = response.get('task_id') print(f"\n✅ 训练请求成功") print(f"🆔 任务ID: {task_id}") print(f"💬 消息: {response.get('message')}") # 等待一段时间让训练开始 print(f"\n⏳ 等待训练开始...") time.sleep(3) # 检查任务状态 print(f"📊 检查训练状态...") status, response = send_get_request(f"{base_url}/api/training") if status == 200: tasks = response.get('data', []) for task in tasks: if task.get('task_id') == task_id: print(f"📋 任务状态: {task.get('status')}") if task.get('metrics'): print(f"📈 指标: {task.get('metrics')}") if task.get('error'): print(f"❌ 错误: {task.get('error')}") break else: print(f"❌ 训练请求失败") print(f"🔴 错误: {response}") print("\n" + "=" * 50) print("🎯 测试重点:") print("1. 检查API服务器控制台是否有调试输出") print("2. 确认线程是否正常启动") print("3. 观察训练函数是否被调用") print("=" * 50) if __name__ == "__main__": test_api_training()