ShopTRAINING/test/direct_training_log_test.py
2025-07-02 11:05:23 +08:00

99 lines
3.3 KiB
Python
Raw Blame History

This file contains invisible Unicode characters

This file contains invisible Unicode characters that are indistinguishable to humans but may be processed differently by a computer. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
直接测试训练器的日志输出
"""
import os
import sys
import time
# 设置编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
if os.name == 'nt':
try:
os.system('chcp 65001 >nul 2>&1')
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"Warning: Failed to set UTF-8 encoding: {e}")
def test_direct_training():
"""直接测试训练器功能和日志输出"""
print("🧪 直接测试训练器日志输出")
print("=" * 50)
try:
# 添加server目录到路径
server_dir = os.path.join(os.getcwd(), 'server')
if server_dir not in sys.path:
sys.path.insert(0, server_dir)
print("1⃣ 导入模块...")
from trainers.transformer_trainer import train_product_model_with_transformer
from utils.training_progress import TrainingProgressManager
print("✅ 模块导入成功")
print("\n2⃣ 创建进度管理器...")
def progress_callback(message):
print(f"📊 [进度回调] {message}")
progress_manager = TrainingProgressManager(websocket_callback=progress_callback)
print("✅ 进度管理器创建成功")
print("\n3⃣ 检查数据文件...")
if not os.path.exists('pharmacy_sales_multi_store.csv'):
print("❌ 数据文件不存在,请先运行生成数据脚本")
return False
print("✅ 数据文件存在")
print("\n4⃣ 开始训练测试...")
print("-" * 30)
# 直接调用训练器
result = train_product_model_with_transformer(
product_id='P001',
epochs=2, # 只训练2轮用于测试
store_id=None,
training_mode='product',
aggregation_method='sum',
model_dir='saved_models',
version=None,
socketio=None, # 不使用socketio
task_id='direct_test_001',
continue_training=False
)
print("-" * 30)
print(f"4⃣ 训练完成,结果: {type(result)}")
if result and len(result) >= 2:
model, metrics = result[0], result[1]
print(f"✅ 训练成功!")
print(f"📊 返回指标: {metrics}")
else:
print(f"⚠️ 训练结果异常: {result}")
return True
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return False
if __name__ == "__main__":
print("🔧 直接测试训练器日志功能")
print("这将直接调用训练器,观察日志输出\n")
success = test_direct_training()
print("\n" + "=" * 50)
if success:
print("🎉 测试完成!")
print("\n📋 观察上面的输出:")
print("- 如果看到详细的训练进度信息,说明训练器日志正常")
print("- 如果看到进度回调信息,说明进度管理器工作正常")
else:
print("❌ 测试失败")