""" 训练进度管理器 提供实时训练进度跟踪、速度计算和时间预估 """ import time import threading from typing import Optional, Dict, Any, Callable from dataclasses import dataclass from datetime import datetime, timedelta import json @dataclass class TrainingMetrics: """训练指标数据类""" epoch: int total_epochs: int batch: int total_batches: int current_loss: float avg_loss: float learning_rate: float # 时间相关 epoch_start_time: float epoch_duration: float total_duration: float # 速度指标 batches_per_second: float samples_per_second: float # 预估时间 eta_current_epoch: float # 当前轮次剩余时间 eta_total: float # 总剩余时间 # 阶段信息 stage: str # 'data_loading', 'training', 'validation', 'saving' stage_progress: float # 当前阶段进度 0-100 class TrainingProgressManager: """训练进度管理器""" def __init__(self, websocket_callback: Optional[Callable] = None): """ 初始化进度管理器 Args: websocket_callback: WebSocket回调函数,用于实时推送进度 """ self.websocket_callback = websocket_callback self._lock = threading.Lock() self.reset() def reset(self): """重置所有进度信息""" with self._lock: self.training_id = None self.product_id = None self.model_type = None self.training_mode = None # 训练配置 self.total_epochs = 0 self.total_batches_per_epoch = 0 self.batch_size = 0 self.total_samples = 0 # 当前状态 self.current_epoch = 0 self.current_batch = 0 self.current_stage = "preparing" self.stage_progress = 0.0 # 时间跟踪 self.start_time = None self.epoch_start_time = None self.batch_times = [] self.epoch_times = [] # 损失跟踪 self.epoch_losses = [] self.current_epoch_losses = [] # 状态标志 self.is_training = False self.is_cancelled = False self.is_completed = False def start_training(self, training_id: str, product_id: str, model_type: str, training_mode: str, total_epochs: int, total_batches: int, batch_size: int, total_samples: int): """开始训练""" with self._lock: self.reset() self.training_id = training_id self.product_id = product_id self.model_type = model_type self.training_mode = training_mode self.total_epochs = total_epochs self.total_batches_per_epoch = total_batches self.batch_size = batch_size self.total_samples = total_samples self.start_time = time.time() self.is_training = True self._broadcast_progress("training_started") def start_epoch(self, epoch: int): """开始新的训练轮次""" with self._lock: self.current_epoch = epoch self.current_batch = 0 self.epoch_start_time = time.time() self.current_epoch_losses = [] self.current_stage = "training" self.stage_progress = 0.0 self._broadcast_progress("epoch_started") def update_batch(self, batch: int, loss: float, learning_rate: float = 0.001): """更新批次进度""" with self._lock: if not self.is_training: return self.current_batch = batch self.current_epoch_losses.append(loss) # 计算当前阶段进度 self.stage_progress = (batch / self.total_batches_per_epoch) * 100 # 记录批次时间 current_time = time.time() if self.epoch_start_time: batch_duration = current_time - self.epoch_start_time self.batch_times.append(batch_duration / (batch + 1)) # 计算训练指标 metrics = self._calculate_metrics(loss, learning_rate) # 每10个批次或最后一个批次广播一次 if batch % 10 == 0 or batch == self.total_batches_per_epoch - 1: self._broadcast_progress("batch_update", metrics) def finish_epoch(self, epoch_loss: float, validation_loss: Optional[float] = None): """完成当前轮次""" with self._lock: if not self.is_training: return # 记录轮次时间 if self.epoch_start_time: epoch_duration = time.time() - self.epoch_start_time self.epoch_times.append(epoch_duration) # 记录损失 self.epoch_losses.append({ 'epoch': self.current_epoch, 'train_loss': epoch_loss, 'validation_loss': validation_loss, 'timestamp': datetime.now().isoformat() }) metrics = self._calculate_metrics(epoch_loss, 0.001) self._broadcast_progress("epoch_completed", metrics) def set_stage(self, stage: str, progress: float = 0.0): """设置当前训练阶段""" with self._lock: self.current_stage = stage self.stage_progress = progress stage_info = { 'stage': stage, 'progress': progress, 'timestamp': datetime.now().isoformat() } self._broadcast_progress("stage_update", stage_info) def finish_training(self, success: bool = True, error_message: str = None): """完成训练""" with self._lock: self.is_training = False self.is_completed = success if success: self.current_stage = "completed" self.stage_progress = 100.0 else: self.current_stage = "failed" finish_info = { 'success': success, 'error_message': error_message, 'total_duration': time.time() - self.start_time if self.start_time else 0, 'total_epochs_completed': self.current_epoch, 'final_loss': self.epoch_losses[-1]['train_loss'] if self.epoch_losses else None } self._broadcast_progress("training_finished", finish_info) def cancel_training(self): """取消训练""" with self._lock: self.is_cancelled = True self.is_training = False self.current_stage = "cancelled" self._broadcast_progress("training_cancelled") def _calculate_metrics(self, current_loss: float, learning_rate: float) -> TrainingMetrics: """计算训练指标""" current_time = time.time() # 计算平均损失 avg_loss = sum(self.current_epoch_losses) / len(self.current_epoch_losses) if self.current_epoch_losses else current_loss # 计算时间相关指标 epoch_duration = current_time - self.epoch_start_time if self.epoch_start_time else 0 total_duration = current_time - self.start_time if self.start_time else 0 # 计算速度指标 batches_per_second = self.current_batch / epoch_duration if epoch_duration > 0 else 0 samples_per_second = batches_per_second * self.batch_size # 计算预估时间 if batches_per_second > 0: remaining_batches_current_epoch = self.total_batches_per_epoch - self.current_batch eta_current_epoch = remaining_batches_current_epoch / batches_per_second else: eta_current_epoch = 0 # 基于历史轮次时间预估总剩余时间 if self.epoch_times: avg_epoch_time = sum(self.epoch_times) / len(self.epoch_times) remaining_epochs = self.total_epochs - self.current_epoch - 1 eta_total = eta_current_epoch + (remaining_epochs * avg_epoch_time) else: # 基于当前轮次进度估算 if epoch_duration > 0 and self.current_batch > 0: estimated_epoch_time = epoch_duration * (self.total_batches_per_epoch / self.current_batch) remaining_epochs = self.total_epochs - self.current_epoch - 1 eta_total = eta_current_epoch + (remaining_epochs * estimated_epoch_time) else: eta_total = 0 return TrainingMetrics( epoch=self.current_epoch, total_epochs=self.total_epochs, batch=self.current_batch, total_batches=self.total_batches_per_epoch, current_loss=current_loss, avg_loss=avg_loss, learning_rate=learning_rate, epoch_start_time=self.epoch_start_time or 0, epoch_duration=epoch_duration, total_duration=total_duration, batches_per_second=batches_per_second, samples_per_second=samples_per_second, eta_current_epoch=eta_current_epoch, eta_total=eta_total, stage=self.current_stage, stage_progress=self.stage_progress ) def _broadcast_progress(self, event_type: str, data: Any = None): """广播进度更新""" if not self.websocket_callback: return try: message = { 'event_type': event_type, 'training_id': self.training_id, 'product_id': self.product_id, 'model_type': self.model_type, 'training_mode': self.training_mode, 'timestamp': datetime.now().isoformat(), 'data': data } # 如果data是TrainingMetrics对象,转换为字典 if isinstance(data, TrainingMetrics): message['data'] = { 'epoch': data.epoch, 'total_epochs': data.total_epochs, 'batch': data.batch, 'total_batches': data.total_batches, 'current_loss': round(data.current_loss, 6), 'avg_loss': round(data.avg_loss, 6), 'learning_rate': data.learning_rate, 'epoch_duration': round(data.epoch_duration, 2), 'total_duration': round(data.total_duration, 2), 'batches_per_second': round(data.batches_per_second, 2), 'samples_per_second': round(data.samples_per_second, 0), 'eta_current_epoch': round(data.eta_current_epoch, 1), 'eta_total': round(data.eta_total, 1), 'stage': data.stage, 'stage_progress': round(data.stage_progress, 1), 'overall_progress': round((data.epoch / data.total_epochs) * 100, 1) } self.websocket_callback(message) except Exception as e: print(f"Broadcast failed: {e}") def get_current_status(self) -> Dict[str, Any]: """获取当前训练状态""" with self._lock: if not self.is_training and not self.is_completed: return {'status': 'idle'} current_loss = self.current_epoch_losses[-1] if self.current_epoch_losses else 0 metrics = self._calculate_metrics(current_loss, 0.001) return { 'status': 'training' if self.is_training else ('completed' if self.is_completed else 'idle'), 'training_id': self.training_id, 'product_id': self.product_id, 'model_type': self.model_type, 'training_mode': self.training_mode, 'current_epoch': self.current_epoch, 'total_epochs': self.total_epochs, 'current_batch': self.current_batch, 'total_batches': self.total_batches_per_epoch, 'current_stage': self.current_stage, 'stage_progress': self.stage_progress, 'overall_progress': (self.current_epoch / self.total_epochs) * 100 if self.total_epochs > 0 else 0, 'eta_total': metrics.eta_total if hasattr(metrics, 'eta_total') else 0, 'is_cancelled': self.is_cancelled } # 全局进度管理器实例 progress_manager = TrainingProgressManager()