#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ HTTP训练测试 - 调用已运行的API服务器进行训练测试 要求: 请先在另一个终端运行 uv run server/api.py 启动API服务器 """ import os import sys # 在导入其他模块前设置编码 if os.name == 'nt': # Windows系统编码设置 os.environ['PYTHONIOENCODING'] = 'utf-8' os.system('chcp 65001 >nul 2>&1') import requests import json import time def test_api_training(): """测试API训练功能""" print("🧪 API训练测试开始") print("="*50) api_url = "http://127.0.0.1:5000" try: # 1. 检查API服务器状态 print("🔍 检查API服务器状态...") try: response = requests.get(f"{api_url}/api/products", timeout=5) if response.status_code != 200: print("❌ API服务器未运行") print("💡 请先运行: uv run server/api.py") return False print("✅ API服务器运行正常") except requests.exceptions.ConnectionError: print("❌ 无法连接到API服务器") print("💡 请先运行: uv run server/api.py") return False # 2. 获取产品列表 print("\n📋 获取产品列表...") products = response.json().get('data', []) if not products: print("❌ 没有找到产品数据") return False product_id = products[0]['product_id'] product_name = products[0]['product_name'] print(f"✅ 选择产品: {product_id} - {product_name}") # 3. 启动训练任务 print(f"\n🚀 启动训练任务...") training_data = { "product_id": product_id, "model_type": "transformer", "epochs": 3, # 使用较少轮次快速测试 "training_mode": "product" } print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}") print("🔍 请观察API服务器控制台输出...") print("-" * 60) response = requests.post( f"{api_url}/api/training", json=training_data, timeout=180 # 3分钟超时 ) print("-" * 60) print(f"📡 API响应状态: {response.status_code}") if response.status_code == 200: result = response.json() task_id = result.get('task_id') print(f"✅ 训练任务已启动") print(f"🔖 任务ID: {task_id}") print(f"💬 消息: {result.get('message', 'N/A')}") # 4. 监控训练状态 print(f"\n⏳ 监控训练状态...") max_checks = 30 # 最多检查30次 for i in range(max_checks): time.sleep(2) # 每2秒检查一次 try: status_response = requests.get(f"{api_url}/api/training", timeout=5) if status_response.status_code == 200: tasks = status_response.json().get('data', []) current_task = None for task in tasks: if task.get('task_id') == task_id: current_task = task break if current_task: status = current_task.get('status', 'unknown') print(f"📊 训练状态: {status} (检查 {i+1}/{max_checks})") if status == 'completed': print(f"\n✅ 训练成功完成!") metrics = current_task.get('metrics', {}) if metrics: print(f"📈 训练指标:") for key, value in metrics.items(): if isinstance(value, (int, float)): print(f" {key}: {value:.4f}") else: print(f" {key}: {value}") else: print("⚠️ 训练指标为空") return True elif status == 'failed': print(f"\n❌ 训练失败") error = current_task.get('error', 'Unknown error') print(f"🔴 错误信息: {error}") return False else: print(f"⚠️ 找不到任务 {task_id}") except Exception as e: print(f"⚠️ 状态检查异常: {e}") print(f"\n⏰ 训练监控超时,但任务可能仍在运行") print(f"💡 请查看API服务器控制台确认训练状态") else: print(f"❌ 训练启动失败: {response.status_code}") try: error_info = response.json() print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False)}") except: print(f"🔴 错误内容: {response.text}") return False except Exception as e: print(f"❌ 测试异常: {e}") return False return True def main(): """主函数""" print("\n" + "="*60) print("🔧 药店销售预测系统 - API训练测试") print("="*60) print("📋 测试步骤:") print(" 1. 检查API服务器连接") print(" 2. 获取产品列表") print(" 3. 启动训练任务") print(" 4. 监控训练进度") print("📝 注意: 请在另一个终端运行 'uv run server/api.py' 启动API服务器") print("="*60) success = test_api_training() print("\n" + "="*60) if success: print("🎉 API训练测试完成!") else: print("❌ API训练测试失败!") print("💡 关键是观察API服务器控制台是否有完整的训练日志输出") print("="*60) if __name__ == "__main__": main()