ShopTRAINING/test/test_training_import.py
2025-07-02 11:05:23 +08:00

165 lines
5.3 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
# 设置完整的UTF-8编码环境
os.environ['PYTHONIOENCODING'] = 'utf-8'
os.environ['PYTHONLEGACYWINDOWSSTDIO'] = '0'
# Windows控制台编码设置
if os.name == 'nt':
try:
import subprocess
subprocess.run(['chcp', '65001'], capture_output=True, shell=True, check=False)
# 重新配置标准输出
if hasattr(sys.stdout, 'reconfigure'):
sys.stdout.reconfigure(encoding='utf-8', errors='replace')
sys.stderr.reconfigure(encoding='utf-8', errors='replace')
except Exception as e:
print(f"警告: UTF-8编码设置失败: {e}")
def test_training_imports():
"""测试训练相关模块的导入"""
print("🔍 训练模块导入测试")
print("=" * 60)
# 添加server路径
import os as os_module
current_dir = os_module.path.dirname(os_module.path.abspath(__file__))
server_dir = os_module.path.join(current_dir, 'server')
sys.path.insert(0, server_dir)
# 测试各种导入
tests = [
("core.predictor", "PharmacyPredictor"),
("trainers.transformer_trainer", "train_product_model_with_transformer"),
("utils.training_process_manager", "TrainingProcessManager"),
("utils.logging_config", "setup_api_logging"),
]
for module_name, class_or_func in tests:
try:
print(f"📦 测试导入 {module_name}.{class_or_func}...")
module = __import__(module_name, fromlist=[class_or_func])
item = getattr(module, class_or_func)
print(f"✅ 成功: {module_name}.{class_or_func}")
except Exception as e:
print(f"❌ 失败: {module_name}.{class_or_func} - {e}")
print("\n" + "=" * 60)
# 测试训练进程管理器的实际导入过程
print("🔍 模拟训练进程管理器的导入过程")
print("-" * 40)
try:
# 模拟training_process_manager中的导入逻辑
import sys
import os as os_mod
mgr_server_dir = os_mod.path.dirname(os_mod.path.dirname(__file__))
mgr_server_dir = os_mod.path.join(mgr_server_dir, 'server')
print(f"📁 添加路径: {mgr_server_dir}")
if mgr_server_dir not in sys.path:
sys.path.append(mgr_server_dir)
print("📦 尝试导入 core.predictor.PharmacyPredictor...")
from core.predictor import PharmacyPredictor
print("🤖 尝试实例化 PharmacyPredictor...")
predictor = PharmacyPredictor()
print("✅ 训练器导入和实例化成功!")
print("💡 这说明导入问题已解决")
return True
except Exception as e:
print(f"❌ 训练器导入失败: {e}")
print(f"🔍 错误详情:")
import traceback
traceback.print_exc()
print(f"\n💡 可能的解决方案:")
print(f"1. 检查server目录结构")
print(f"2. 确认core.predictor模块存在")
print(f"3. 检查依赖模块导入")
return False
def test_direct_training_call():
"""测试直接调用训练函数"""
print("\n🚀 直接训练调用测试")
print("=" * 60)
try:
# 添加server路径
current_dir = os.path.dirname(os.path.abspath(__file__))
server_dir = os.path.join(current_dir, 'server')
sys.path.insert(0, server_dir)
print("📦 导入 PharmacyPredictor...")
from core.predictor import PharmacyPredictor
print("🤖 创建预测器实例...")
predictor = PharmacyPredictor()
print("🎯 调用训练方法...")
print("参数: product_id=P001, model_type=transformer, epochs=1")
# 调用训练方法,使用最少的轮次
metrics = predictor.train_model(
product_id='P001',
model_type='transformer',
epochs=1,
training_mode='product'
)
print("✅ 训练调用成功!")
if metrics:
print("📊 返回的训练指标:")
for key, value in metrics.items():
if isinstance(value, (int, float)):
print(f" {key}: {value:.4f}")
else:
print(f" {key}: {value}")
else:
print("⚠️ 训练指标为空")
return True
except Exception as e:
print(f"❌ 直接训练调用失败: {e}")
import traceback
traceback.print_exc()
return False
def main():
print("🧪 训练导入和调用问题诊断")
print("目标: 找出为什么API训练返回空指标")
print("=" * 80)
# 测试导入
import_success = test_training_imports()
if import_success:
# 测试直接调用
call_success = test_direct_training_call()
if call_success:
print("\n🎉 诊断结果: 训练模块工作正常")
print("💡 问题可能在训练进程管理器的参数传递或路径配置")
else:
print("\n⚠️ 诊断结果: 训练调用存在问题")
else:
print("\n❌ 诊断结果: 模块导入失败")
print("=" * 80)
if __name__ == "__main__":
main()