#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 测试API训练流程 - 直接调用训练API并观察控制台输出 """ import requests import json import time import sys import os # 设置UTF-8编码 if os.name == 'nt': os.system('chcp 65001 >nul 2>&1') import io sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding='utf-8', errors='replace', line_buffering=True) sys.stderr = io.TextIOWrapper(sys.stderr.buffer, encoding='utf-8', errors='replace', line_buffering=True) API_BASE_URL = "http://127.0.0.1:5000" def test_training_api(): """测试训练API调用""" print("\n" + "="*60, flush=True) print("🧪 API训练流程测试开始", flush=True) print("="*60, flush=True) # 检查API服务器状态 try: response = requests.get(f"{API_BASE_URL}/api/health", timeout=5) print(f"✅ API服务器状态: {response.status_code}", flush=True) if response.status_code == 200: print(f"📊 服务器响应: {response.json()}", flush=True) except Exception as e: print(f"❌ API服务器连接失败: {e}", flush=True) return # 检查是否有训练数据 try: print("\n🔍 检查训练数据...", flush=True) if not os.path.exists("pharmacy_sales_multi_store.csv"): print("⚠️ 未找到训练数据文件,正在生成...", flush=True) import subprocess result = subprocess.run(['python', 'generate_multi_store_data.py'], capture_output=True, text=True, cwd=".") if result.returncode == 0: print("✅ 训练数据生成成功", flush=True) else: print(f"❌ 数据生成失败: {result.stderr}", flush=True) return else: print("✅ 找到训练数据文件", flush=True) except Exception as e: print(f"❌ 数据检查失败: {e}", flush=True) # 获取产品列表 try: print("\n📋 获取产品列表...", flush=True) response = requests.get(f"{API_BASE_URL}/api/products", timeout=10) if response.status_code == 200: products = response.json().get('data', []) print(f"✅ 找到 {len(products)} 个产品", flush=True) if products: product_id = products[0]['product_id'] print(f"🎯 选择产品: {product_id} - {products[0]['product_name']}", flush=True) else: print("❌ 没有找到产品数据", flush=True) return else: print(f"❌ 获取产品列表失败: {response.status_code}", flush=True) return except Exception as e: print(f"❌ 获取产品列表异常: {e}", flush=True) return # 启动训练任务 print(f"\n🚀 启动训练任务...", flush=True) training_data = { "product_id": product_id, "model_type": "transformer", "epochs": 3, # 使用较少的轮次快速测试 "training_mode": "product" } print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False, indent=2)}", flush=True) try: response = requests.post( f"{API_BASE_URL}/api/training", json=training_data, timeout=300 # 5分钟超时 ) print(f"\n📡 API响应状态: {response.status_code}", flush=True) if response.status_code == 200: result = response.json() task_id = result.get('task_id') print(f"✅ 训练任务已启动: {task_id}", flush=True) print(f"💬 响应消息: {result.get('message', 'N/A')}", flush=True) # 等待训练完成并检查状态 print(f"\n⏳ 等待训练完成...", flush=True) for i in range(60): # 最多等待60秒 time.sleep(1) try: status_response = requests.get(f"{API_BASE_URL}/api/training", timeout=5) if status_response.status_code == 200: tasks = status_response.json().get('data', []) current_task = None for task in tasks: if task.get('task_id') == task_id: current_task = task break if current_task: status = current_task.get('status', 'unknown') print(f"📊 任务状态: {status} ({i+1}s)", end='\r', flush=True) if status == 'completed': print(f"\n✅ 训练完成!") metrics = current_task.get('metrics', {}) if metrics: print(f"📈 训练指标: {json.dumps(metrics, ensure_ascii=False, indent=2)}", flush=True) else: print("⚠️ 训练指标为空", flush=True) break elif status == 'failed': print(f"\n❌ 训练失败!") error = current_task.get('error', 'Unknown error') print(f"🔴 错误信息: {error}", flush=True) break except Exception as e: print(f"\n⚠️ 状态检查异常: {e}", flush=True) else: print(f"\n⏰ 训练超时,但任务可能仍在后台运行", flush=True) else: print(f"❌ 训练启动失败: {response.status_code}", flush=True) try: error_info = response.json() print(f"🔴 错误详情: {json.dumps(error_info, ensure_ascii=False, indent=2)}", flush=True) except: print(f"🔴 错误内容: {response.text}", flush=True) except requests.exceptions.Timeout: print(f"⏰ 请求超时,训练可能仍在进行中", flush=True) except Exception as e: print(f"❌ 训练请求异常: {e}", flush=True) print("\n" + "="*60, flush=True) print("🎉 API训练流程测试完成", flush=True) print("💡 请查看API服务器控制台输出以确认训练日志是否正常显示", flush=True) print("="*60, flush=True) if __name__ == "__main__": test_training_api()