102 lines
3.3 KiB
Python
102 lines
3.3 KiB
Python
#\!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
测试API训练日志功能
|
||
运行一个简单的训练任务来验证日志输出
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
|
||
# 设置编码环境
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
if os.name == 'nt':
|
||
try:
|
||
os.system('chcp 65001 >nul 2>&1')
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
except Exception as e:
|
||
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
||
|
||
def test_direct_training():
|
||
"""直接测试训练功能"""
|
||
print("🧪 直接测试训练功能和日志输出")
|
||
print("=" * 50)
|
||
|
||
try:
|
||
# 添加server目录到路径
|
||
server_dir = os.path.join(os.getcwd(), 'server')
|
||
if server_dir not in sys.path:
|
||
sys.path.insert(0, server_dir)
|
||
|
||
# 导入必要模块
|
||
from core.predictor import PharmacyPredictor
|
||
from utils.training_progress import TrainingProgressManager
|
||
|
||
print("✅ 模块导入成功")
|
||
|
||
# 创建进度管理器实例
|
||
def progress_callback(progress_data):
|
||
"""进度回调函数"""
|
||
event_type = progress_data.get('event_type', 'unknown')
|
||
print(f"📊 [进度回调] {event_type}: {progress_data.get('data', {})}")
|
||
|
||
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
|
||
print("✅ 进度管理器创建成功")
|
||
|
||
# 检查数据文件是否存在
|
||
data_file = 'pharmacy_sales_multi_store.csv'
|
||
if not os.path.exists(data_file):
|
||
print(f"❌ 数据文件不存在: {data_file}")
|
||
print("请先运行: PYTHONIOENCODING=utf-8 uv run generate_multi_store_data.py")
|
||
return False
|
||
|
||
print(f"✅ 找到数据文件: {data_file}")
|
||
|
||
# 创建预测器
|
||
predictor = PharmacyPredictor()
|
||
print("✅ 预测器创建成功")
|
||
|
||
# 运行短期训练测试
|
||
print("\n🤖 开始训练测试(轮次=3)...")
|
||
print("-" * 30)
|
||
|
||
metrics = predictor.train_model(
|
||
product_id='P001',
|
||
model_type='transformer',
|
||
epochs=3, # 只训练3轮用于测试
|
||
training_mode='product'
|
||
)
|
||
|
||
print("-" * 30)
|
||
if metrics:
|
||
print("✅ 训练完成\!")
|
||
print(f"📊 返回指标: {metrics}")
|
||
else:
|
||
print("❌ 训练失败,未返回指标")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
if __name__ == "__main__":
|
||
print("🔧 测试训练日志修复效果")
|
||
print("这将运行一个短期训练来验证日志输出是否正常\n")
|
||
|
||
success = test_direct_training()
|
||
|
||
print("\n" + "=" * 50)
|
||
if success:
|
||
print("🎉 测试完成\!")
|
||
print("\n📋 如果看到了详细的训练进度输出,说明修复成功\!")
|
||
print("现在可以启动API服务器,训练日志应该正常显示")
|
||
else:
|
||
print("❌ 测试失败,需要进一步调试")
|
||
EOF < /dev/null
|