77 lines
2.3 KiB
Python
77 lines
2.3 KiB
Python
#!/usr/bin/env python
|
|
"""
|
|
测试进度管理器
|
|
"""
|
|
|
|
import sys
|
|
import os
|
|
import time
|
|
sys.path.append('server')
|
|
|
|
def test_progress_manager_only():
|
|
"""只测试进度管理器功能"""
|
|
print("=== 测试进度管理器功能 ===")
|
|
|
|
try:
|
|
from utils.training_progress import progress_manager
|
|
|
|
print("导入进度管理器成功")
|
|
|
|
# 测试进度管理器基本功能
|
|
progress_manager.start_training(
|
|
training_id="test_001",
|
|
product_id="P001",
|
|
model_type="tcn",
|
|
training_mode="product",
|
|
total_epochs=3,
|
|
total_batches=5,
|
|
batch_size=32,
|
|
total_samples=160
|
|
)
|
|
print("开始训练跟踪")
|
|
|
|
# 模拟训练过程
|
|
for epoch in range(3):
|
|
progress_manager.start_epoch(epoch)
|
|
print(f"开始轮次 {epoch}")
|
|
|
|
for batch in range(5):
|
|
loss = 1.0 - (epoch * 0.1) - (batch * 0.02)
|
|
progress_manager.update_batch(batch, loss, 0.001)
|
|
print(f" 批次 {batch}, 损失: {loss:.4f}")
|
|
|
|
epoch_loss = 0.8 - (epoch * 0.1)
|
|
progress_manager.finish_epoch(epoch_loss)
|
|
print(f"完成轮次 {epoch}, 平均损失: {epoch_loss:.4f}")
|
|
|
|
# 获取当前状态
|
|
status = progress_manager.get_current_status()
|
|
print(f" 状态: {status['status']}, 总进度: {status['overall_progress']:.1f}%")
|
|
|
|
# 完成训练
|
|
progress_manager.finish_training(success=True)
|
|
print("完成训练跟踪")
|
|
|
|
# 测试最终状态
|
|
final_status = progress_manager.get_current_status()
|
|
print(f"最终状态: {final_status['status']}")
|
|
|
|
return True
|
|
|
|
except Exception as e:
|
|
print(f"进度管理器测试失败: {e}")
|
|
import traceback
|
|
traceback.print_exc()
|
|
return False
|
|
|
|
if __name__ == "__main__":
|
|
if test_progress_manager_only():
|
|
print("\n成功!进度管理器工作正常")
|
|
print("\n功能验证:")
|
|
print("- 训练开始跟踪")
|
|
print("- 轮次进度管理")
|
|
print("- 批次进度更新")
|
|
print("- 实时状态获取")
|
|
print("- 训练完成处理")
|
|
else:
|
|
print("\n测试失败") |