102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
![]() |
#\!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
测试API训练日志功能
|
|||
|
运行一个简单的训练任务来验证日志输出
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
|
|||
|
# 设置编码环境
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
if os.name == 'nt':
|
|||
|
try:
|
|||
|
os.system('chcp 65001 >nul 2>&1')
|
|||
|
if hasattr(sys.stdout, 'reconfigure'):
|
|||
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
except Exception as e:
|
|||
|
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
|||
|
|
|||
|
def test_direct_training():
|
|||
|
"""直接测试训练功能"""
|
|||
|
print("🧪 直接测试训练功能和日志输出")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
try:
|
|||
|
# 添加server目录到路径
|
|||
|
server_dir = os.path.join(os.getcwd(), 'server')
|
|||
|
if server_dir not in sys.path:
|
|||
|
sys.path.insert(0, server_dir)
|
|||
|
|
|||
|
# 导入必要模块
|
|||
|
from core.predictor import PharmacyPredictor
|
|||
|
from utils.training_progress import TrainingProgressManager
|
|||
|
|
|||
|
print("✅ 模块导入成功")
|
|||
|
|
|||
|
# 创建进度管理器实例
|
|||
|
def progress_callback(progress_data):
|
|||
|
"""进度回调函数"""
|
|||
|
event_type = progress_data.get('event_type', 'unknown')
|
|||
|
print(f"📊 [进度回调] {event_type}: {progress_data.get('data', {})}")
|
|||
|
|
|||
|
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
|
|||
|
print("✅ 进度管理器创建成功")
|
|||
|
|
|||
|
# 检查数据文件是否存在
|
|||
|
data_file = 'pharmacy_sales_multi_store.csv'
|
|||
|
if not os.path.exists(data_file):
|
|||
|
print(f"❌ 数据文件不存在: {data_file}")
|
|||
|
print("请先运行: PYTHONIOENCODING=utf-8 uv run generate_multi_store_data.py")
|
|||
|
return False
|
|||
|
|
|||
|
print(f"✅ 找到数据文件: {data_file}")
|
|||
|
|
|||
|
# 创建预测器
|
|||
|
predictor = PharmacyPredictor()
|
|||
|
print("✅ 预测器创建成功")
|
|||
|
|
|||
|
# 运行短期训练测试
|
|||
|
print("\n🤖 开始训练测试(轮次=3)...")
|
|||
|
print("-" * 30)
|
|||
|
|
|||
|
metrics = predictor.train_model(
|
|||
|
product_id='P001',
|
|||
|
model_type='transformer',
|
|||
|
epochs=3, # 只训练3轮用于测试
|
|||
|
training_mode='product'
|
|||
|
)
|
|||
|
|
|||
|
print("-" * 30)
|
|||
|
if metrics:
|
|||
|
print("✅ 训练完成\!")
|
|||
|
print(f"📊 返回指标: {metrics}")
|
|||
|
else:
|
|||
|
print("❌ 训练失败,未返回指标")
|
|||
|
|
|||
|
return True
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 测试失败: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return False
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
print("🔧 测试训练日志修复效果")
|
|||
|
print("这将运行一个短期训练来验证日志输出是否正常\n")
|
|||
|
|
|||
|
success = test_direct_training()
|
|||
|
|
|||
|
print("\n" + "=" * 50)
|
|||
|
if success:
|
|||
|
print("🎉 测试完成\!")
|
|||
|
print("\n📋 如果看到了详细的训练进度输出,说明修复成功\!")
|
|||
|
print("现在可以启动API服务器,训练日志应该正常显示")
|
|||
|
else:
|
|||
|
print("❌ 测试失败,需要进一步调试")
|
|||
|
EOF < /dev/null
|