99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
直接测试训练器的日志输出
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import time
|
||
|
||
# 设置编码环境
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
if os.name == 'nt':
|
||
try:
|
||
os.system('chcp 65001 >nul 2>&1')
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
except Exception as e:
|
||
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
||
|
||
def test_direct_training():
|
||
"""直接测试训练器功能和日志输出"""
|
||
print("🧪 直接测试训练器日志输出")
|
||
print("=" * 50)
|
||
|
||
try:
|
||
# 添加server目录到路径
|
||
server_dir = os.path.join(os.getcwd(), 'server')
|
||
if server_dir not in sys.path:
|
||
sys.path.insert(0, server_dir)
|
||
|
||
print("1️⃣ 导入模块...")
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
from utils.training_progress import TrainingProgressManager
|
||
print("✅ 模块导入成功")
|
||
|
||
print("\n2️⃣ 创建进度管理器...")
|
||
def progress_callback(message):
|
||
print(f"📊 [进度回调] {message}")
|
||
|
||
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
|
||
print("✅ 进度管理器创建成功")
|
||
|
||
print("\n3️⃣ 检查数据文件...")
|
||
if not os.path.exists('pharmacy_sales_multi_store.csv'):
|
||
print("❌ 数据文件不存在,请先运行生成数据脚本")
|
||
return False
|
||
print("✅ 数据文件存在")
|
||
|
||
print("\n4️⃣ 开始训练测试...")
|
||
print("-" * 30)
|
||
|
||
# 直接调用训练器
|
||
result = train_product_model_with_transformer(
|
||
product_id='P001',
|
||
epochs=2, # 只训练2轮用于测试
|
||
store_id=None,
|
||
training_mode='product',
|
||
aggregation_method='sum',
|
||
model_dir='saved_models',
|
||
version=None,
|
||
socketio=None, # 不使用socketio
|
||
task_id='direct_test_001',
|
||
continue_training=False
|
||
)
|
||
|
||
print("-" * 30)
|
||
print(f"4️⃣ 训练完成,结果: {type(result)}")
|
||
|
||
if result and len(result) >= 2:
|
||
model, metrics = result[0], result[1]
|
||
print(f"✅ 训练成功!")
|
||
print(f"📊 返回指标: {metrics}")
|
||
else:
|
||
print(f"⚠️ 训练结果异常: {result}")
|
||
|
||
return True
|
||
|
||
except Exception as e:
|
||
print(f"❌ 测试失败: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
return False
|
||
|
||
if __name__ == "__main__":
|
||
print("🔧 直接测试训练器日志功能")
|
||
print("这将直接调用训练器,观察日志输出\n")
|
||
|
||
success = test_direct_training()
|
||
|
||
print("\n" + "=" * 50)
|
||
if success:
|
||
print("🎉 测试完成!")
|
||
print("\n📋 观察上面的输出:")
|
||
print("- 如果看到详细的训练进度信息,说明训练器日志正常")
|
||
print("- 如果看到进度回调信息,说明进度管理器工作正常")
|
||
else:
|
||
print("❌ 测试失败") |