122 lines
4.2 KiB
Python
122 lines
4.2 KiB
Python
#!/usr/bin/env python3
|
||
# -*- coding: utf-8 -*-
|
||
"""
|
||
训练测试 - 正确的编码配置
|
||
测试训练器是否正常输出日志,保留所有中文和emoji字符
|
||
"""
|
||
|
||
import os
|
||
import sys
|
||
import subprocess
|
||
|
||
def setup_encoding():
|
||
"""设置正确的UTF-8编码环境"""
|
||
os.environ['PYTHONIOENCODING'] = 'utf-8'
|
||
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
|
||
|
||
# Windows控制台编码设置
|
||
if os.name == 'nt':
|
||
try:
|
||
# 设置控制台代码页为UTF-8
|
||
subprocess.run(['chcp', '65001'],
|
||
capture_output=True,
|
||
shell=True,
|
||
check=False)
|
||
|
||
# 重新配置标准输出
|
||
if hasattr(sys.stdout, 'reconfigure'):
|
||
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
|
||
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
|
||
except Exception as e:
|
||
print(f"警告:编码设置失败: {e}")
|
||
|
||
def test_direct_training():
|
||
"""测试直接训练功能"""
|
||
|
||
print("=" * 60)
|
||
print("🧪 训练日志输出测试开始")
|
||
print("💡 目标: 验证训练器控制台输出是否正常")
|
||
print("📋 策略: 保留所有中文和emoji,通过编码配置解决乱码")
|
||
print("=" * 60)
|
||
|
||
# 添加server路径
|
||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||
server_dir = os.path.join(current_dir, 'server')
|
||
sys.path.insert(0, server_dir)
|
||
|
||
try:
|
||
print("📦 正在导入训练器模块...")
|
||
from trainers.transformer_trainer import train_product_model_with_transformer
|
||
print("✅ 训练器模块导入成功")
|
||
|
||
print("\n🚀 开始训练测试")
|
||
print("📊 产品ID: P001")
|
||
print("🤖 模型类型: Transformer")
|
||
print("⚙️ 训练轮次: 2 (快速测试)")
|
||
print("🎯 期望: 看到详细的训练进度输出")
|
||
print("-" * 60)
|
||
|
||
# 调用训练器 - 应该能正常显示所有中文和emoji
|
||
result = train_product_model_with_transformer(
|
||
product_id='P001',
|
||
epochs=2,
|
||
training_mode='product'
|
||
)
|
||
|
||
print("-" * 60)
|
||
print("🎉 训练测试完成!")
|
||
|
||
if result:
|
||
model, metrics, version = result
|
||
print(f"📊 返回结果:")
|
||
print(f" 模型对象: {type(model).__name__}")
|
||
print(f" 模型版本: {version}")
|
||
|
||
if metrics:
|
||
print(f" 训练指标:")
|
||
for key, value in metrics.items():
|
||
if isinstance(value, float):
|
||
print(f" {key}: {value:.4f}")
|
||
else:
|
||
print(f" {key}: {value}")
|
||
else:
|
||
print(" ⚠️ 训练指标为空")
|
||
|
||
print("✅ 训练成功完成,返回了完整结果")
|
||
else:
|
||
print("❌ 训练返回None,可能存在问题")
|
||
|
||
except Exception as e:
|
||
print(f"❌ 训练过程发生错误: {e}")
|
||
print("\n🔍 错误详情:")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
# 提供解决建议
|
||
print("\n💡 可能的解决方案:")
|
||
print("1. 确认数据文件存在: pharmacy_sales_multi_store.csv")
|
||
print("2. 检查编码环境变量: PYTHONIOENCODING=utf-8")
|
||
print("3. 使用批处理文件启动: .\\启动API服务器.bat")
|
||
|
||
print("\n" + "=" * 60)
|
||
print("🏁 训练日志输出测试结束")
|
||
print("💬 如果看到了包含emoji和中文的训练进度,说明编码配置成功")
|
||
print("📝 重要: 本次测试保留了所有原始中文和emoji字符")
|
||
print("=" * 60)
|
||
|
||
def main():
|
||
"""主函数"""
|
||
# 设置编码环境
|
||
setup_encoding()
|
||
|
||
print("🔧 编码环境配置:")
|
||
print(f" PYTHONIOENCODING: {os.environ.get('PYTHONIOENCODING', '未设置')}")
|
||
print(f" PYTHONLEGACYWINDOWSSTDIO: {os.environ.get('PYTHONLEGACYWINDOWSSTDIO', '未设置')}")
|
||
print(f" 系统平台: {os.name}")
|
||
print()
|
||
|
||
# 运行训练测试
|
||
test_direct_training()
|
||
|
||
if __name__ == "__main__":
|
||
main() |