ShopTRAINING/test/test_api_training_logs.py
2025-07-02 11:05:23 +08:00

102 lines
3.3 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#\!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
测试API训练日志功能
运行一个简单的训练任务来验证日志输出
"""
import os
import sys
import time
# 设置编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
if os.name == 'nt':
try:
os.system('chcp 65001 >nul 2>&1')
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"Warning: Failed to set UTF-8 encoding: {e}")
def test_direct_training():
"""直接测试训练功能"""
print("🧪 直接测试训练功能和日志输出")
print("=" * 50)
try:
# 添加server目录到路径
server_dir = os.path.join(os.getcwd(), 'server')
if server_dir not in sys.path:
sys.path.insert(0, server_dir)
# 导入必要模块
from core.predictor import PharmacyPredictor
from utils.training_progress import TrainingProgressManager
print("✅ 模块导入成功")
# 创建进度管理器实例
def progress_callback(progress_data):
"""进度回调函数"""
event_type = progress_data.get('event_type', 'unknown')
print(f"📊 [进度回调] {event_type}: {progress_data.get('data', {})}")
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
print("✅ 进度管理器创建成功")
# 检查数据文件是否存在
data_file = 'pharmacy_sales_multi_store.csv'
if not os.path.exists(data_file):
print(f"❌ 数据文件不存在: {data_file}")
print("请先运行: PYTHONIOENCODING=utf-8 uv run generate_multi_store_data.py")
return False
print(f"✅ 找到数据文件: {data_file}")
# 创建预测器
predictor = PharmacyPredictor()
print("✅ 预测器创建成功")
# 运行短期训练测试
print("\n🤖 开始训练测试(轮次=3...")
print("-" * 30)
metrics = predictor.train_model(
product_id='P001',
model_type='transformer',
epochs=3, # 只训练3轮用于测试
training_mode='product'
)
print("-" * 30)
if metrics:
print("✅ 训练完成\!")
print(f"📊 返回指标: {metrics}")
else:
print("❌ 训练失败,未返回指标")
return True
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
print("🔧 测试训练日志修复效果")
print("这将运行一个短期训练来验证日志输出是否正常\n")
success = test_direct_training()
print("\n" + "=" * 50)
if success:
print("🎉 测试完成\!")
print("\n📋 如果看到了详细的训练进度输出,说明修复成功\!")
print("现在可以启动API服务器训练日志应该正常显示")
else:
print("❌ 测试失败,需要进一步调试")
EOF < /dev/null