99 lines
3.3 KiB
Python
99 lines
3.3 KiB
Python
![]() |
#!/usr/bin/env python3
|
|||
|
# -*- coding: utf-8 -*-
|
|||
|
"""
|
|||
|
直接测试训练器的日志输出
|
|||
|
"""
|
|||
|
|
|||
|
import os
|
|||
|
import sys
|
|||
|
import time
|
|||
|
|
|||
|
# 设置编码环境
|
|||
|
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
|||
|
if os.name == 'nt':
|
|||
|
try:
|
|||
|
os.system('chcp 65001 >nul 2>&1')
|
|||
|
if hasattr(sys.stdout, 'reconfigure'):
|
|||
|
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
|||
|
except Exception as e:
|
|||
|
print(f"Warning: Failed to set UTF-8 encoding: {e}")
|
|||
|
|
|||
|
def test_direct_training():
|
|||
|
"""直接测试训练器功能和日志输出"""
|
|||
|
print("🧪 直接测试训练器日志输出")
|
|||
|
print("=" * 50)
|
|||
|
|
|||
|
try:
|
|||
|
# 添加server目录到路径
|
|||
|
server_dir = os.path.join(os.getcwd(), 'server')
|
|||
|
if server_dir not in sys.path:
|
|||
|
sys.path.insert(0, server_dir)
|
|||
|
|
|||
|
print("1️⃣ 导入模块...")
|
|||
|
from trainers.transformer_trainer import train_product_model_with_transformer
|
|||
|
from utils.training_progress import TrainingProgressManager
|
|||
|
print("✅ 模块导入成功")
|
|||
|
|
|||
|
print("\n2️⃣ 创建进度管理器...")
|
|||
|
def progress_callback(message):
|
|||
|
print(f"📊 [进度回调] {message}")
|
|||
|
|
|||
|
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
|
|||
|
print("✅ 进度管理器创建成功")
|
|||
|
|
|||
|
print("\n3️⃣ 检查数据文件...")
|
|||
|
if not os.path.exists('pharmacy_sales_multi_store.csv'):
|
|||
|
print("❌ 数据文件不存在,请先运行生成数据脚本")
|
|||
|
return False
|
|||
|
print("✅ 数据文件存在")
|
|||
|
|
|||
|
print("\n4️⃣ 开始训练测试...")
|
|||
|
print("-" * 30)
|
|||
|
|
|||
|
# 直接调用训练器
|
|||
|
result = train_product_model_with_transformer(
|
|||
|
product_id='P001',
|
|||
|
epochs=2, # 只训练2轮用于测试
|
|||
|
store_id=None,
|
|||
|
training_mode='product',
|
|||
|
aggregation_method='sum',
|
|||
|
model_dir='saved_models',
|
|||
|
version=None,
|
|||
|
socketio=None, # 不使用socketio
|
|||
|
task_id='direct_test_001',
|
|||
|
continue_training=False
|
|||
|
)
|
|||
|
|
|||
|
print("-" * 30)
|
|||
|
print(f"4️⃣ 训练完成,结果: {type(result)}")
|
|||
|
|
|||
|
if result and len(result) >= 2:
|
|||
|
model, metrics = result[0], result[1]
|
|||
|
print(f"✅ 训练成功!")
|
|||
|
print(f"📊 返回指标: {metrics}")
|
|||
|
else:
|
|||
|
print(f"⚠️ 训练结果异常: {result}")
|
|||
|
|
|||
|
return True
|
|||
|
|
|||
|
except Exception as e:
|
|||
|
print(f"❌ 测试失败: {e}")
|
|||
|
import traceback
|
|||
|
traceback.print_exc()
|
|||
|
return False
|
|||
|
|
|||
|
if __name__ == "__main__":
|
|||
|
print("🔧 直接测试训练器日志功能")
|
|||
|
print("这将直接调用训练器,观察日志输出\n")
|
|||
|
|
|||
|
success = test_direct_training()
|
|||
|
|
|||
|
print("\n" + "=" * 50)
|
|||
|
if success:
|
|||
|
print("🎉 测试完成!")
|
|||
|
print("\n📋 观察上面的输出:")
|
|||
|
print("- 如果看到详细的训练进度信息,说明训练器日志正常")
|
|||
|
print("- 如果看到进度回调信息,说明进度管理器工作正常")
|
|||
|
else:
|
|||
|
print("❌ 测试失败")
|