340 lines
13 KiB
Python
340 lines
13 KiB
Python
"""
|
||
训练进度管理器
|
||
提供实时训练进度跟踪、速度计算和时间预估
|
||
"""
|
||
|
||
import json
import logging
import threading
import time
from dataclasses import dataclass
from datetime import datetime, timedelta
from typing import Optional, Dict, Any, Callable
|
||
|
||
|
||
@dataclass
class TrainingMetrics:
    """Snapshot of training progress produced by ``TrainingProgressManager``."""

    epoch: int
    total_epochs: int
    batch: int
    total_batches: int
    current_loss: float
    avg_loss: float
    learning_rate: float

    # Timing (seconds; epoch_start_time is a time.time() timestamp)
    epoch_start_time: float
    epoch_duration: float
    total_duration: float

    # Throughput
    batches_per_second: float
    samples_per_second: float

    # Estimated remaining time (seconds)
    eta_current_epoch: float  # remaining time for the current epoch
    eta_total: float  # remaining time for the whole run

    # Stage information
    stage: str  # e.g. 'data_loading', 'training', 'validation', 'saving'
    stage_progress: float  # progress of the current stage, 0-100


class TrainingProgressManager:
    """Track training progress, throughput and ETA, and push updates to a callback.

    Every public method runs under ``self._lock``. The lock is re-entrant
    because ``start_training`` calls ``reset`` while already holding it, and
    because the progress callback is invoked with the lock held (a callback
    that calls back into this manager from the same thread must not block).
    """

    # Cap on retained per-batch timing samples so long runs don't grow memory.
    MAX_BATCH_TIME_SAMPLES = 1000

    # Learning rate reported before any batch has supplied one.
    DEFAULT_LEARNING_RATE = 0.001

    def __init__(self, websocket_callback: Optional[Callable] = None):
        """Initialize the progress manager.

        Args:
            websocket_callback: Callable invoked with a message dict for every
                broadcast event (used to push progress in real time).
        """
        self.websocket_callback = websocket_callback
        # RLock (not Lock): start_training() calls reset() while holding it.
        self._lock = threading.RLock()
        self.reset()

    def reset(self):
        """Reset all progress state to its initial values."""
        with self._lock:
            self.training_id = None
            self.product_id = None
            self.model_type = None
            self.training_mode = None

            # Training configuration
            self.total_epochs = 0
            self.total_batches_per_epoch = 0
            self.batch_size = 0
            self.total_samples = 0

            # Current state
            self.current_epoch = 0
            self.current_batch = 0  # index of the last reported batch (0-based)
            self.batches_completed = 0  # batches finished in the current epoch
            self.current_stage = "preparing"
            self.stage_progress = 0.0
            self.current_learning_rate = self.DEFAULT_LEARNING_RATE

            # Time tracking
            self.start_time = None
            self.epoch_start_time = None
            self.batch_times = []  # running average seconds/batch, bounded
            self.epoch_times = []

            # Loss tracking
            self.epoch_losses = []
            self.current_epoch_losses = []

            # Status flags
            self.is_training = False
            self.is_cancelled = False
            self.is_completed = False

    def start_training(self, training_id: str, product_id: str, model_type: str,
                       training_mode: str, total_epochs: int, total_batches: int,
                       batch_size: int, total_samples: int):
        """Reset state, record the run configuration and broadcast ``training_started``."""
        with self._lock:
            self.reset()  # safe: the lock is re-entrant
            self.training_id = training_id
            self.product_id = product_id
            self.model_type = model_type
            self.training_mode = training_mode
            self.total_epochs = total_epochs
            self.total_batches_per_epoch = total_batches
            self.batch_size = batch_size
            self.total_samples = total_samples
            self.start_time = time.time()
            self.is_training = True

            self._broadcast_progress("training_started")

    def start_epoch(self, epoch: int):
        """Begin a new epoch and broadcast ``epoch_started``."""
        with self._lock:
            self.current_epoch = epoch
            self.current_batch = 0
            self.batches_completed = 0
            self.epoch_start_time = time.time()
            self.current_epoch_losses = []
            self.current_stage = "training"
            self.stage_progress = 0.0

            self._broadcast_progress("epoch_started")

    def update_batch(self, batch: int, loss: float, learning_rate: float = 0.001):
        """Record one finished batch.

        Args:
            batch: 0-based index of the finished batch within the epoch.
            loss: Loss of that batch.
            learning_rate: Learning rate used for that batch.

        Ignored unless training is active. Broadcasts ``batch_update`` every
        10 batches and on the last batch of the epoch.
        """
        with self._lock:
            if not self.is_training:
                return

            self.current_batch = batch
            # batch is 0-based, so batch + 1 batches are done.
            self.batches_completed = batch + 1
            self.current_learning_rate = learning_rate
            self.current_epoch_losses.append(loss)

            # Stage progress; undefined (left unchanged) when the epoch has no batches.
            if self.total_batches_per_epoch > 0:
                self.stage_progress = min(
                    100.0, self.batches_completed / self.total_batches_per_epoch * 100
                )

            # Running average seconds per batch within this epoch.
            if self.epoch_start_time is not None:
                elapsed = time.time() - self.epoch_start_time
                self.batch_times.append(elapsed / self.batches_completed)
                # Trim in bulk (amortized O(1)) to bound memory on long runs.
                if len(self.batch_times) > 2 * self.MAX_BATCH_TIME_SAMPLES:
                    del self.batch_times[:-self.MAX_BATCH_TIME_SAMPLES]

            # Throttle broadcasts; metrics are only needed when one is sent.
            if batch % 10 == 0 or batch == self.total_batches_per_epoch - 1:
                metrics = self._calculate_metrics(loss, learning_rate)
                self._broadcast_progress("batch_update", metrics)

    def finish_epoch(self, epoch_loss: float, validation_loss: Optional[float] = None):
        """Record the end of the current epoch and broadcast ``epoch_completed``.

        Ignored unless training is active.
        """
        with self._lock:
            if not self.is_training:
                return

            if self.epoch_start_time is not None:
                self.epoch_times.append(time.time() - self.epoch_start_time)

            self.epoch_losses.append({
                'epoch': self.current_epoch,
                'train_loss': epoch_loss,
                'validation_loss': validation_loss,
                'timestamp': datetime.now().isoformat()
            })

            # Report the last learning rate actually seen, not a hardcoded value.
            metrics = self._calculate_metrics(epoch_loss, self.current_learning_rate)
            self._broadcast_progress("epoch_completed", metrics)

    def set_stage(self, stage: str, progress: float = 0.0):
        """Set the current stage and its progress, then broadcast ``stage_update``."""
        with self._lock:
            self.current_stage = stage
            self.stage_progress = progress

            stage_info = {
                'stage': stage,
                'progress': progress,
                'timestamp': datetime.now().isoformat()
            }

            self._broadcast_progress("stage_update", stage_info)

    def finish_training(self, success: bool = True, error_message: Optional[str] = None):
        """Mark training finished (successfully or not) and broadcast ``training_finished``."""
        with self._lock:
            self.is_training = False
            self.is_completed = success

            if success:
                self.current_stage = "completed"
                self.stage_progress = 100.0
            else:
                self.current_stage = "failed"

            finish_info = {
                'success': success,
                'error_message': error_message,
                'total_duration': time.time() - self.start_time if self.start_time else 0,
                'total_epochs_completed': self.current_epoch,
                'final_loss': self.epoch_losses[-1]['train_loss'] if self.epoch_losses else None
            }

            self._broadcast_progress("training_finished", finish_info)

    def cancel_training(self):
        """Mark training cancelled and broadcast ``training_cancelled``."""
        with self._lock:
            self.is_cancelled = True
            self.is_training = False
            self.current_stage = "cancelled"

            self._broadcast_progress("training_cancelled")

    def _calculate_metrics(self, current_loss: float, learning_rate: float) -> TrainingMetrics:
        """Build a ``TrainingMetrics`` snapshot from the current state.

        Callers hold ``self._lock``.
        """
        now = time.time()

        losses = self.current_epoch_losses
        avg_loss = sum(losses) / len(losses) if losses else current_loss

        epoch_duration = now - self.epoch_start_time if self.epoch_start_time else 0
        total_duration = now - self.start_time if self.start_time else 0

        # Throughput is based on completed batches, not the 0-based index.
        done = self.batches_completed
        batches_per_second = done / epoch_duration if epoch_duration > 0 else 0
        samples_per_second = batches_per_second * self.batch_size

        remaining_batches = max(0, self.total_batches_per_epoch - done)
        eta_current_epoch = remaining_batches / batches_per_second if batches_per_second > 0 else 0

        # Whole-run ETA: prefer the mean of finished epochs; otherwise
        # extrapolate the current epoch's pace to a full epoch.
        remaining_epochs = max(0, self.total_epochs - self.current_epoch - 1)
        if self.epoch_times:
            per_epoch = sum(self.epoch_times) / len(self.epoch_times)
        elif epoch_duration > 0 and done > 0:
            per_epoch = epoch_duration * (self.total_batches_per_epoch / done)
        else:
            per_epoch = 0
        eta_total = eta_current_epoch + remaining_epochs * per_epoch

        return TrainingMetrics(
            epoch=self.current_epoch,
            total_epochs=self.total_epochs,
            batch=self.current_batch,
            total_batches=self.total_batches_per_epoch,
            current_loss=current_loss,
            avg_loss=avg_loss,
            learning_rate=learning_rate,
            epoch_start_time=self.epoch_start_time or 0,
            epoch_duration=epoch_duration,
            total_duration=total_duration,
            batches_per_second=batches_per_second,
            samples_per_second=samples_per_second,
            eta_current_epoch=eta_current_epoch,
            eta_total=eta_total,
            stage=self.current_stage,
            stage_progress=self.stage_progress
        )

    @staticmethod
    def _metrics_to_payload(data: TrainingMetrics) -> Dict[str, Any]:
        """Flatten metrics into the rounded dict sent to clients."""
        # Guard against total_epochs == 0 (would raise ZeroDivisionError).
        overall = (data.epoch / data.total_epochs) * 100 if data.total_epochs > 0 else 0.0
        return {
            'epoch': data.epoch,
            'total_epochs': data.total_epochs,
            'batch': data.batch,
            'total_batches': data.total_batches,
            'current_loss': round(data.current_loss, 6),
            'avg_loss': round(data.avg_loss, 6),
            'learning_rate': data.learning_rate,
            'epoch_duration': round(data.epoch_duration, 2),
            'total_duration': round(data.total_duration, 2),
            'batches_per_second': round(data.batches_per_second, 2),
            'samples_per_second': round(data.samples_per_second, 0),
            'eta_current_epoch': round(data.eta_current_epoch, 1),
            'eta_total': round(data.eta_total, 1),
            'stage': data.stage,
            'stage_progress': round(data.stage_progress, 1),
            'overall_progress': round(overall, 1)
        }

    def _broadcast_progress(self, event_type: str, data: Any = None):
        """Send a progress message to the callback, if one is configured.

        ``TrainingMetrics`` data is flattened via ``_metrics_to_payload``;
        any other data is forwarded unchanged. Failures are logged and never
        propagate into the caller (the training loop).
        """
        if not self.websocket_callback:
            return

        try:
            payload = self._metrics_to_payload(data) if isinstance(data, TrainingMetrics) else data
            message = {
                'event_type': event_type,
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'timestamp': datetime.now().isoformat(),
                'data': payload
            }
            self.websocket_callback(message)
        except Exception as exc:
            # Best-effort delivery: a failing subscriber must not stop training.
            logging.getLogger(__name__).exception("Broadcast failed: %s", exc)

    def _status_label(self) -> str:
        """Return 'training', 'completed', 'cancelled', 'failed' or 'idle'."""
        if self.is_training:
            return 'training'
        if self.is_completed:
            return 'completed'
        if self.is_cancelled:
            return 'cancelled'
        if self.current_stage == 'failed':
            return 'failed'
        return 'idle'

    def get_current_status(self) -> Dict[str, Any]:
        """Return a status snapshot; ``{'status': 'idle'}`` when nothing has run."""
        with self._lock:
            status = self._status_label()
            if status == 'idle':
                return {'status': 'idle'}

            current_loss = self.current_epoch_losses[-1] if self.current_epoch_losses else 0
            metrics = self._calculate_metrics(current_loss, self.current_learning_rate)

            return {
                'status': status,
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'current_epoch': self.current_epoch,
                'total_epochs': self.total_epochs,
                'current_batch': self.current_batch,
                'total_batches': self.total_batches_per_epoch,
                'current_stage': self.current_stage,
                'stage_progress': self.stage_progress,
                'overall_progress': (self.current_epoch / self.total_epochs) * 100 if self.total_epochs > 0 else 0,
                'eta_total': metrics.eta_total,
                'is_cancelled': self.is_cancelled
            }


# Shared, module-level progress manager instance
progress_manager = TrainingProgressManager()