ShopTRAINING/server/utils/training_progress.py

340 lines
13 KiB
Python
Raw Normal View History

2025-07-02 11:05:23 +08:00
"""
训练进度管理器
提供实时训练进度跟踪速度计算和时间预估
"""
import time
import threading
from typing import Optional, Dict, Any, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
import json
@dataclass
class TrainingMetrics:
    """Training metrics data class: one snapshot of progress, timing and speed."""
    epoch: int
    total_epochs: int
    batch: int
    total_batches: int
    current_loss: float
    avg_loss: float
    learning_rate: float
    # Timing
    epoch_start_time: float
    epoch_duration: float
    total_duration: float
    # Speed metrics
    batches_per_second: float
    samples_per_second: float
    # Estimated time remaining
    eta_current_epoch: float  # time remaining in the current epoch
    eta_total: float  # total time remaining
    # Stage information
    stage: str  # 'data_loading', 'training', 'validation', 'saving'
    stage_progress: float  # progress of the current stage, 0-100
class TrainingProgressManager:
    """Track the progress of one training run and broadcast updates.

    State is mutated under an internal lock. Progress events
    ("training_started", "epoch_started", "batch_update", "epoch_completed",
    "stage_update", "training_finished", "training_cancelled") are passed as
    message dicts to an optional callback. The callback is always invoked
    after the lock has been released so a slow or re-entrant callback cannot
    stall or deadlock the training thread.
    """

    # Learning rate reported before any batch has supplied a real one.
    _DEFAULT_LEARNING_RATE = 0.001
    # A "batch_update" event is emitted every this many batches.
    _BROADCAST_EVERY = 10

    def __init__(self, websocket_callback: Optional[Callable] = None):
        """Create an idle progress manager.

        Args:
            websocket_callback: Optional callable invoked with one message
                dict per progress event. When falsy, nothing is broadcast.
        """
        self.websocket_callback = websocket_callback
        # Re-entrant on purpose: start_training() calls reset() while already
        # holding the lock, which deadlocks with a plain threading.Lock.
        self._lock = threading.RLock()
        self.reset()

    def reset(self):
        """Reset all progress information to the initial "preparing" state."""
        with self._lock:
            self.training_id = None
            self.product_id = None
            self.model_type = None
            self.training_mode = None

            # Training configuration
            self.total_epochs = 0
            self.total_batches_per_epoch = 0
            self.batch_size = 0
            self.total_samples = 0

            # Current position
            self.current_epoch = 0
            self.current_batch = 0
            self.current_stage = "preparing"
            self.stage_progress = 0.0

            # Time tracking
            self.start_time = None
            self.epoch_start_time = None
            self.batch_times = []
            self.epoch_times = []

            # Loss tracking
            self.epoch_losses = []
            self.current_epoch_losses = []
            self._last_learning_rate = self._DEFAULT_LEARNING_RATE

            # Status flags
            self.is_training = False
            self.is_cancelled = False
            self.is_completed = False

    def start_training(self, training_id: str, product_id: str, model_type: str,
                       training_mode: str, total_epochs: int, total_batches: int,
                       batch_size: int, total_samples: int):
        """Reset the manager, record the run configuration and mark it running.

        Args:
            training_id: Identifier of the training run.
            product_id: Identifier of the product being trained on.
            model_type: Name of the model type.
            training_mode: Name of the training mode.
            total_epochs: Number of epochs planned.
            total_batches: Number of batches per epoch.
            batch_size: Samples per batch (used for samples-per-second).
            total_samples: Total number of samples.
        """
        with self._lock:
            self.reset()
            self.training_id = training_id
            self.product_id = product_id
            self.model_type = model_type
            self.training_mode = training_mode
            self.total_epochs = total_epochs
            self.total_batches_per_epoch = total_batches
            self.batch_size = batch_size
            self.total_samples = total_samples
            self.start_time = time.time()
            self.is_training = True
        self._broadcast_progress("training_started")

    def start_epoch(self, epoch: int):
        """Begin a new epoch: reset per-epoch counters and start its timer."""
        with self._lock:
            self.current_epoch = epoch
            self.current_batch = 0
            self.epoch_start_time = time.time()
            self.current_epoch_losses = []
            self.current_stage = "training"
            self.stage_progress = 0.0
        self._broadcast_progress("epoch_started")

    def update_batch(self, batch: int, loss: float, learning_rate: float = 0.001):
        """Record one finished batch and periodically broadcast metrics.

        Does nothing when no training is in progress. A "batch_update" event
        is emitted every 10th batch and when ``batch`` equals
        ``total_batches_per_epoch - 1``.

        Args:
            batch: Index of the batch within the current epoch.
            loss: Loss value of this batch.
            learning_rate: Learning rate in effect for this batch.
        """
        with self._lock:
            if not self.is_training:
                return
            self.current_batch = batch
            self.current_epoch_losses.append(loss)
            self._last_learning_rate = learning_rate

            total = self.total_batches_per_epoch
            # Guard against a run configured with zero batches per epoch.
            self.stage_progress = (batch / total) * 100 if total > 0 else 0.0

            if self.epoch_start_time:
                elapsed = time.time() - self.epoch_start_time
                # Average time per batch so far in this epoch.
                self.batch_times.append(elapsed / (batch + 1))

            metrics = self._calculate_metrics(loss, learning_rate)
            should_broadcast = (
                batch % self._BROADCAST_EVERY == 0 or batch == total - 1
            )
        if should_broadcast:
            self._broadcast_progress("batch_update", metrics)

    def finish_epoch(self, epoch_loss: float, validation_loss: Optional[float] = None):
        """Close the current epoch, store its losses and broadcast a summary.

        Does nothing when no training is in progress.

        Args:
            epoch_loss: Training loss of the finished epoch.
            validation_loss: Validation loss of the epoch, if one was computed.
        """
        with self._lock:
            if not self.is_training:
                return
            if self.epoch_start_time:
                self.epoch_times.append(time.time() - self.epoch_start_time)
            self.epoch_losses.append({
                'epoch': self.current_epoch,
                'train_loss': epoch_loss,
                'validation_loss': validation_loss,
                'timestamp': datetime.now().isoformat()
            })
            # Report the most recent real learning rate instead of a constant.
            metrics = self._calculate_metrics(epoch_loss, self._last_learning_rate)
        self._broadcast_progress("epoch_completed", metrics)

    def set_stage(self, stage: str, progress: float = 0.0):
        """Set the current training stage and broadcast a "stage_update".

        Args:
            stage: Name of the stage.
            progress: Progress within that stage.
        """
        with self._lock:
            self.current_stage = stage
            self.stage_progress = progress
        stage_info = {
            'stage': stage,
            'progress': progress,
            'timestamp': datetime.now().isoformat()
        }
        self._broadcast_progress("stage_update", stage_info)

    def finish_training(self, success: bool = True, error_message: Optional[str] = None):
        """Mark the run as finished and broadcast the final summary.

        Args:
            success: True when training completed normally; False marks the
                run as "failed".
            error_message: Optional description of the failure.
        """
        with self._lock:
            self.is_training = False
            self.is_completed = success
            if success:
                self.current_stage = "completed"
                self.stage_progress = 100.0
            else:
                self.current_stage = "failed"
            finish_info = {
                'success': success,
                'error_message': error_message,
                'total_duration': time.time() - self.start_time if self.start_time else 0,
                'total_epochs_completed': self.current_epoch,
                'final_loss': self.epoch_losses[-1]['train_loss'] if self.epoch_losses else None
            }
        self._broadcast_progress("training_finished", finish_info)

    def cancel_training(self):
        """Flag the run as cancelled and broadcast "training_cancelled"."""
        with self._lock:
            self.is_cancelled = True
            self.is_training = False
            self.current_stage = "cancelled"
        self._broadcast_progress("training_cancelled")

    def _calculate_metrics(self, current_loss: float, learning_rate: float) -> "TrainingMetrics":
        """Build a metrics snapshot from the current state.

        Reads instance state without locking; callers in this class invoke it
        while holding the lock.

        Args:
            current_loss: Loss to report as the current value; also used as
                the average when no batch losses were recorded this epoch.
            learning_rate: Learning rate to report.

        Returns:
            A TrainingMetrics instance with durations, speeds and ETAs.
        """
        now = time.time()

        losses = self.current_epoch_losses
        avg_loss = sum(losses) / len(losses) if losses else current_loss

        epoch_duration = now - self.epoch_start_time if self.epoch_start_time else 0
        total_duration = now - self.start_time if self.start_time else 0

        batches_per_second = self.current_batch / epoch_duration if epoch_duration > 0 else 0
        samples_per_second = batches_per_second * self.batch_size

        if batches_per_second > 0:
            # Clamp so an out-of-range batch index never yields a negative ETA.
            remaining_batches = max(self.total_batches_per_epoch - self.current_batch, 0)
            eta_current_epoch = remaining_batches / batches_per_second
        else:
            eta_current_epoch = 0

        # Clamp so 1-based epoch numbering cannot produce a negative count.
        remaining_epochs = max(self.total_epochs - self.current_epoch - 1, 0)
        if self.epoch_times:
            # Prefer the measured average duration of completed epochs.
            time_per_epoch = sum(self.epoch_times) / len(self.epoch_times)
        elif epoch_duration > 0 and self.current_batch > 0:
            # Otherwise extrapolate from progress within the current epoch.
            time_per_epoch = epoch_duration * (self.total_batches_per_epoch / self.current_batch)
        else:
            time_per_epoch = None

        if time_per_epoch is None:
            eta_total = 0
        else:
            eta_total = eta_current_epoch + remaining_epochs * time_per_epoch

        return TrainingMetrics(
            epoch=self.current_epoch,
            total_epochs=self.total_epochs,
            batch=self.current_batch,
            total_batches=self.total_batches_per_epoch,
            current_loss=current_loss,
            avg_loss=avg_loss,
            learning_rate=learning_rate,
            epoch_start_time=self.epoch_start_time or 0,
            epoch_duration=epoch_duration,
            total_duration=total_duration,
            batches_per_second=batches_per_second,
            samples_per_second=samples_per_second,
            eta_current_epoch=eta_current_epoch,
            eta_total=eta_total,
            stage=self.current_stage,
            stage_progress=self.stage_progress
        )

    @staticmethod
    def _metrics_to_payload(data: "TrainingMetrics") -> Dict[str, Any]:
        """Convert a metrics snapshot into the rounded dict sent to clients."""
        if data.total_epochs > 0:
            overall_progress = round((data.epoch / data.total_epochs) * 100, 1)
        else:
            # No epochs configured: report zero instead of dividing by zero.
            overall_progress = 0.0
        return {
            'epoch': data.epoch,
            'total_epochs': data.total_epochs,
            'batch': data.batch,
            'total_batches': data.total_batches,
            'current_loss': round(data.current_loss, 6),
            'avg_loss': round(data.avg_loss, 6),
            'learning_rate': data.learning_rate,
            'epoch_duration': round(data.epoch_duration, 2),
            'total_duration': round(data.total_duration, 2),
            'batches_per_second': round(data.batches_per_second, 2),
            'samples_per_second': round(data.samples_per_second, 0),
            'eta_current_epoch': round(data.eta_current_epoch, 1),
            'eta_total': round(data.eta_total, 1),
            'stage': data.stage,
            'stage_progress': round(data.stage_progress, 1),
            'overall_progress': overall_progress
        }

    def _broadcast_progress(self, event_type: str, data: Any = None):
        """Send one progress event to the callback, if one is configured.

        Args:
            event_type: Name of the event.
            data: Event payload; a TrainingMetrics instance is converted to a
                dict of rounded values, anything else is passed through.
        """
        if not self.websocket_callback:
            return
        try:
            payload = data
            if isinstance(data, TrainingMetrics):
                payload = self._metrics_to_payload(data)
            message = {
                'event_type': event_type,
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'timestamp': datetime.now().isoformat(),
                'data': payload
            }
            self.websocket_callback(message)
        except Exception as e:
            # Best effort: a failing listener must never interrupt training.
            print(f"Broadcast failed: {e}")

    def get_current_status(self) -> Dict[str, Any]:
        """Return the current training status.

        Returns:
            ``{'status': 'idle'}`` when no run is active or completed;
            otherwise a dict with identifiers, position, stage, overall
            progress percentage, total ETA and the cancellation flag.
        """
        with self._lock:
            if not self.is_training and not self.is_completed:
                return {'status': 'idle'}

            last_loss = self.current_epoch_losses[-1] if self.current_epoch_losses else 0
            metrics = self._calculate_metrics(last_loss, self._last_learning_rate)
            if self.total_epochs > 0:
                overall_progress = (self.current_epoch / self.total_epochs) * 100
            else:
                overall_progress = 0

            return {
                # Past the guard above, a run that is not training is completed.
                'status': 'training' if self.is_training else 'completed',
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'current_epoch': self.current_epoch,
                'total_epochs': self.total_epochs,
                'current_batch': self.current_batch,
                'total_batches': self.total_batches_per_epoch,
                'current_stage': self.current_stage,
                'stage_progress': self.stage_progress,
                'overall_progress': overall_progress,
                'eta_total': metrics.eta_total,
                'is_cancelled': self.is_cancelled
            }
# Module-level shared progress manager instance (created without a WebSocket callback)
progress_manager = TrainingProgressManager()