#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 直接测试训练器的日志输出 """ import os import sys import time # 设置编码环境 os.environ['PYTHONIOENCODING'] = 'utf-8' if os.name == 'nt': try: os.system('chcp 65001 >nul 2>&1') if hasattr(sys.stdout, 'reconfigure'): sys.stdout.reconfigure(encoding='utf-8', errors='replace') sys.stderr.reconfigure(encoding='utf-8', errors='replace') except Exception as e: print(f"Warning: Failed to set UTF-8 encoding: {e}") def test_direct_training(): """直接测试训练器功能和日志输出""" print("🧪 直接测试训练器日志输出") print("=" * 50) try: # 添加server目录到路径 server_dir = os.path.join(os.getcwd(), 'server') if server_dir not in sys.path: sys.path.insert(0, server_dir) print("1️⃣ 导入模块...") from trainers.transformer_trainer import train_product_model_with_transformer from utils.training_progress import TrainingProgressManager print("✅ 模块导入成功") print("\n2️⃣ 创建进度管理器...") def progress_callback(message): print(f"📊 [进度回调] {message}") progress_manager = TrainingProgressManager(websocket_callback=progress_callback) print("✅ 进度管理器创建成功") print("\n3️⃣ 检查数据文件...") if not os.path.exists('pharmacy_sales_multi_store.csv'): print("❌ 数据文件不存在,请先运行生成数据脚本") return False print("✅ 数据文件存在") print("\n4️⃣ 开始训练测试...") print("-" * 30) # 直接调用训练器 result = train_product_model_with_transformer( product_id='P001', epochs=2, # 只训练2轮用于测试 store_id=None, training_mode='product', aggregation_method='sum', model_dir='saved_models', version=None, socketio=None, # 不使用socketio task_id='direct_test_001', continue_training=False ) print("-" * 30) print(f"4️⃣ 训练完成,结果: {type(result)}") if result and len(result) >= 2: model, metrics = result[0], result[1] print(f"✅ 训练成功!") print(f"📊 返回指标: {metrics}") else: print(f"⚠️ 训练结果异常: {result}") return True except Exception as e: print(f"❌ 测试失败: {e}") import traceback traceback.print_exc() return False if __name__ == "__main__": print("🔧 直接测试训练器日志功能") print("这将直接调用训练器,观察日志输出\n") success = test_direct_training() print("\n" + "=" * 50) if success: print("🎉 测试完成!") print("\n📋 观察上面的输出:") print("- 如果看到详细的训练进度信息,说明训练器日志正常") print("- 如果看到进度回调信息,说明进度管理器工作正常") else: print("❌ 测试失败")