124 lines
4.2 KiB
Python
124 lines
4.2 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding: utf-8 -*-
|
|
"""
|
|
测试训练日志输出 - 带正确的编码配置
|
|
"""
|
|
|
|
import os
|
|
import sys
|
|
|
|
# 在任何输出前设置环境变量
|
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
|
|
|
# Windows控制台编码设置
|
|
if os.name == 'nt':
|
|
os.system('chcp 65001 >nul 2>&1')
|
|
|
|
import requests
|
|
import json
|
|
import time
|
|
|
|
def test_training_with_logs():
|
|
"""测试训练并观察日志输出"""
|
|
|
|
print("🧪 训练日志测试开始")
|
|
print("="*50)
|
|
|
|
api_url = "http://127.0.0.1:5000"
|
|
|
|
try:
|
|
# 检查API连接
|
|
print("🔍 检查API服务器...")
|
|
response = requests.get(f"{api_url}/api/products", timeout=5)
|
|
|
|
if response.status_code != 200:
|
|
print("❌ API服务器未启动")
|
|
print("💡 请运行: 启动API服务器.bat")
|
|
return
|
|
|
|
print("✅ API服务器连接成功")
|
|
|
|
# 获取产品
|
|
products = response.json().get('data', [])
|
|
if not products:
|
|
print("❌ 没有产品数据")
|
|
return
|
|
|
|
product_id = products[0]['product_id']
|
|
product_name = products[0]['product_name']
|
|
print(f"🎯 选择产品: {product_id} - {product_name}")
|
|
|
|
# 发送训练请求
|
|
print(f"\n🚀 发送训练请求...")
|
|
training_data = {
|
|
"product_id": product_id,
|
|
"model_type": "transformer",
|
|
"epochs": 2, # 快速测试
|
|
"training_mode": "product"
|
|
}
|
|
|
|
print(f"📊 训练配置: {json.dumps(training_data, ensure_ascii=False)}")
|
|
print("\n" + "="*60)
|
|
print("🔍 请观察API服务器控制台的详细训练日志...")
|
|
print("📝 应该看到表情符号和中文正确显示")
|
|
print("="*60)
|
|
|
|
# 发送请求
|
|
response = requests.post(
|
|
f"{api_url}/api/training",
|
|
json=training_data,
|
|
timeout=120
|
|
)
|
|
|
|
if response.status_code == 200:
|
|
result = response.json()
|
|
task_id = result.get('task_id')
|
|
print(f"\n✅ 训练请求成功发送")
|
|
print(f"🆔 任务ID: {task_id}")
|
|
|
|
# 等待完成
|
|
print(f"⏳ 等待训练完成...")
|
|
for i in range(30):
|
|
time.sleep(2)
|
|
try:
|
|
status_response = requests.get(f"{api_url}/api/training")
|
|
if status_response.status_code == 200:
|
|
tasks = status_response.json().get('data', [])
|
|
for task in tasks:
|
|
if task.get('task_id') == task_id:
|
|
status = task.get('status')
|
|
if status == 'completed':
|
|
print(f"\n🎉 训练完成!")
|
|
metrics = task.get('metrics')
|
|
if metrics:
|
|
print(f"📊 训练指标: {metrics}")
|
|
return
|
|
elif status == 'failed':
|
|
print(f"\n❌ 训练失败")
|
|
print(f"错误: {task.get('error', 'Unknown')}")
|
|
return
|
|
except:
|
|
pass
|
|
|
|
print(f"\n⏰ 等待超时,训练可能仍在进行")
|
|
else:
|
|
print(f"❌ 训练请求失败: {response.status_code}")
|
|
print(f"错误: {response.text}")
|
|
|
|
except requests.exceptions.ConnectionError:
|
|
print("❌ 无法连接API服务器")
|
|
print("💡 请先运行: 启动API服务器.bat")
|
|
except Exception as e:
|
|
print(f"❌ 测试异常: {e}")
|
|
|
|
print("\n" + "="*50)
|
|
print("🎯 测试重点:")
|
|
print("1. API服务器控制台是否显示完整训练日志")
|
|
print("2. 中文字符是否正确显示")
|
|
print("3. 表情符号是否正确显示")
|
|
print("4. 训练进度是否实时输出")
|
|
print("="*50)
|
|
|
|
if __name__ == "__main__":
|
|
test_training_with_logs() |