ShopTRAINING/server/utils/training_progress.py
2025-07-02 11:05:23 +08:00

340 lines
13 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
训练进度管理器
提供实时训练进度跟踪、速度计算和时间预估
"""
import time
import threading
from typing import Optional, Dict, Any, Callable
from dataclasses import dataclass
from datetime import datetime, timedelta
import json
@dataclass
class TrainingMetrics:
    """Snapshot of training progress captured at a single moment.

    Field order is part of the interface: the generated ``__init__`` accepts
    the values positionally in exactly the order declared below.
    """

    # --- Position within the run ---
    epoch: int
    total_epochs: int
    batch: int
    total_batches: int

    # --- Loss and optimizer state ---
    current_loss: float
    avg_loss: float
    learning_rate: float

    # --- Timing ---
    epoch_start_time: float
    epoch_duration: float
    total_duration: float

    # --- Throughput ---
    batches_per_second: float
    samples_per_second: float

    # --- Estimated time remaining ---
    eta_current_epoch: float  # time left in the current epoch
    eta_total: float  # time left for the whole run

    # --- Stage information ---
    stage: str  # 'data_loading', 'training', 'validation', 'saving'
    stage_progress: float  # progress of the current stage, 0-100
class TrainingProgressManager:
    """Track the progress of a training run and push updates to a callback.

    The manager keeps the run configuration, the current epoch/batch/stage,
    timing samples and loss history. State changes are made while holding
    ``self._lock``; events are delivered to ``websocket_callback`` (when one
    is set) after the lock has been released.
    """

    def __init__(self, websocket_callback: Optional[Callable] = None):
        """Create a manager in the reset ("preparing") state.

        Args:
            websocket_callback: Callable invoked with one message dict per
                progress event. When ``None``, broadcasting is skipped.
        """
        self.websocket_callback = websocket_callback
        # Re-entrant on purpose: start_training() calls reset() while it is
        # already holding the lock, which would deadlock with a plain Lock.
        self._lock = threading.RLock()
        self.reset()

    def reset(self):
        """Reset every piece of progress state to its initial value."""
        with self._lock:
            self.training_id = None
            self.product_id = None
            self.model_type = None
            self.training_mode = None

            # Run configuration
            self.total_epochs = 0
            self.total_batches_per_epoch = 0
            self.batch_size = 0
            self.total_samples = 0

            # Current position
            self.current_epoch = 0
            self.current_batch = 0
            self.current_stage = "preparing"
            self.stage_progress = 0.0

            # Time tracking
            self.start_time = None
            self.epoch_start_time = None
            # NOTE(review): one entry is appended per update_batch() call and
            # nothing in this class trims or reads the list.
            self.batch_times = []
            self.epoch_times = []

            # Loss tracking
            self.epoch_losses = []
            self.current_epoch_losses = []

            # Last learning rate passed to update_batch(); reused for events
            # that do not receive a learning rate of their own.
            self._last_learning_rate = 0.001

            # Status flags
            self.is_training = False
            self.is_cancelled = False
            self.is_completed = False

    def start_training(self, training_id: str, product_id: str, model_type: str,
                       training_mode: str, total_epochs: int, total_batches: int,
                       batch_size: int, total_samples: int):
        """Begin a new run: reset state, store the configuration, start the clock.

        Args:
            training_id: Identifier included in every broadcast message.
            product_id: Product identifier included in every broadcast message.
            model_type: Model type label included in every broadcast message.
            training_mode: Training mode label included in every broadcast message.
            total_epochs: Number of epochs planned for the run.
            total_batches: Number of batches per epoch.
            batch_size: Samples per batch (used for samples-per-second).
            total_samples: Total number of samples (stored only).
        """
        with self._lock:
            self.reset()
            self.training_id = training_id
            self.product_id = product_id
            self.model_type = model_type
            self.training_mode = training_mode
            self.total_epochs = total_epochs
            self.total_batches_per_epoch = total_batches
            self.batch_size = batch_size
            self.total_samples = total_samples
            self.start_time = time.time()
            self.is_training = True
        self._broadcast_progress("training_started")

    def start_epoch(self, epoch: int):
        """Mark the start of ``epoch`` and clear the per-epoch state."""
        with self._lock:
            self.current_epoch = epoch
            self.current_batch = 0
            self.epoch_start_time = time.time()
            self.current_epoch_losses = []
            self.current_stage = "training"
            self.stage_progress = 0.0
        self._broadcast_progress("epoch_started")

    def update_batch(self, batch: int, loss: float, learning_rate: float = 0.001):
        """Record one finished batch and periodically broadcast metrics.

        Does nothing when no training is active. A ``batch_update`` event is
        sent on every 10th batch index and on index ``total_batches - 1``.

        Args:
            batch: Index of the batch within the current epoch.
            loss: Loss value of this batch.
            learning_rate: Learning rate in effect for this batch.
        """
        with self._lock:
            if not self.is_training:
                return
            self.current_batch = batch
            self.current_epoch_losses.append(loss)
            self._last_learning_rate = learning_rate

            total = self.total_batches_per_epoch
            # Guard against a run configured with zero batches per epoch.
            self.stage_progress = (batch / total) * 100 if total > 0 else 0.0

            # Running average of seconds per batch within this epoch.
            if self.epoch_start_time:
                elapsed = time.time() - self.epoch_start_time
                self.batch_times.append(elapsed / (batch + 1))

            metrics = self._calculate_metrics(loss, learning_rate)
            should_broadcast = batch % 10 == 0 or batch == total - 1

        if should_broadcast:
            self._broadcast_progress("batch_update", metrics)

    def finish_epoch(self, epoch_loss: float, validation_loss: Optional[float] = None):
        """Close the current epoch, store its loss and broadcast ``epoch_completed``.

        Does nothing when no training is active.

        Args:
            epoch_loss: Training loss for the epoch.
            validation_loss: Validation loss for the epoch, if available.
        """
        with self._lock:
            if not self.is_training:
                return
            if self.epoch_start_time:
                self.epoch_times.append(time.time() - self.epoch_start_time)
            self.epoch_losses.append({
                'epoch': self.current_epoch,
                'train_loss': epoch_loss,
                'validation_loss': validation_loss,
                'timestamp': datetime.now().isoformat()
            })
            # Report the last learning rate actually seen, not a constant.
            metrics = self._calculate_metrics(epoch_loss, self._last_learning_rate)
        self._broadcast_progress("epoch_completed", metrics)

    def set_stage(self, stage: str, progress: float = 0.0):
        """Set the current stage and its progress, then broadcast ``stage_update``."""
        with self._lock:
            self.current_stage = stage
            self.stage_progress = progress
            stage_info = {
                'stage': stage,
                'progress': progress,
                'timestamp': datetime.now().isoformat()
            }
        self._broadcast_progress("stage_update", stage_info)

    def finish_training(self, success: bool = True, error_message: Optional[str] = None):
        """Mark the run as finished and broadcast ``training_finished``.

        Args:
            success: ``True`` marks the run "completed" at 100% stage progress;
                ``False`` marks it "failed".
            error_message: Optional error description included in the event.
        """
        with self._lock:
            self.is_training = False
            self.is_completed = success
            if success:
                self.current_stage = "completed"
                self.stage_progress = 100.0
            else:
                self.current_stage = "failed"
            finish_info = {
                'success': success,
                'error_message': error_message,
                'total_duration': time.time() - self.start_time if self.start_time else 0,
                'total_epochs_completed': self.current_epoch,
                'final_loss': self.epoch_losses[-1]['train_loss'] if self.epoch_losses else None
            }
        self._broadcast_progress("training_finished", finish_info)

    def cancel_training(self):
        """Flag the run as cancelled and broadcast ``training_cancelled``."""
        with self._lock:
            self.is_cancelled = True
            self.is_training = False
            self.current_stage = "cancelled"
        self._broadcast_progress("training_cancelled")

    def _calculate_metrics(self, current_loss: float, learning_rate: float) -> "TrainingMetrics":
        """Build a ``TrainingMetrics`` snapshot from the current state.

        Does not acquire the lock itself; callers in this class invoke it
        while holding ``self._lock``.

        Args:
            current_loss: Loss value to report as the current loss; also used
                as the average when no batch losses were recorded this epoch.
            learning_rate: Learning rate to report.

        Returns:
            The populated ``TrainingMetrics`` instance.
        """
        now = time.time()

        # Average loss over the batches recorded in the current epoch.
        losses = self.current_epoch_losses
        avg_loss = sum(losses) / len(losses) if losses else current_loss

        # Elapsed time
        epoch_duration = now - self.epoch_start_time if self.epoch_start_time else 0
        total_duration = now - self.start_time if self.start_time else 0

        # Throughput
        batches_per_second = self.current_batch / epoch_duration if epoch_duration > 0 else 0
        samples_per_second = batches_per_second * self.batch_size

        # Time left in the current epoch (never negative).
        if batches_per_second > 0:
            remaining_batches = max(0, self.total_batches_per_epoch - self.current_batch)
            eta_current_epoch = remaining_batches / batches_per_second
        else:
            eta_current_epoch = 0

        # Epochs still to run after the current one; clamped so the estimate
        # cannot turn negative when current_epoch reaches total_epochs.
        remaining_epochs = max(0, self.total_epochs - self.current_epoch - 1)

        if self.epoch_times:
            # Prefer the average of epochs that have actually finished.
            seconds_per_epoch = sum(self.epoch_times) / len(self.epoch_times)
        elif epoch_duration > 0 and self.current_batch > 0:
            # Otherwise extrapolate from progress inside the current epoch.
            seconds_per_epoch = epoch_duration * (
                self.total_batches_per_epoch / self.current_batch
            )
        else:
            seconds_per_epoch = None

        if seconds_per_epoch is None:
            eta_total = 0
        else:
            eta_total = eta_current_epoch + remaining_epochs * seconds_per_epoch

        return TrainingMetrics(
            epoch=self.current_epoch,
            total_epochs=self.total_epochs,
            batch=self.current_batch,
            total_batches=self.total_batches_per_epoch,
            current_loss=current_loss,
            avg_loss=avg_loss,
            learning_rate=learning_rate,
            epoch_start_time=self.epoch_start_time or 0,
            epoch_duration=epoch_duration,
            total_duration=total_duration,
            batches_per_second=batches_per_second,
            samples_per_second=samples_per_second,
            eta_current_epoch=eta_current_epoch,
            eta_total=eta_total,
            stage=self.current_stage,
            stage_progress=self.stage_progress
        )

    def _broadcast_progress(self, event_type: str, data: Any = None):
        """Send one progress event to ``websocket_callback``.

        Returns immediately when no callback is configured. Any exception
        raised while building or delivering the message is caught and
        printed, so a failing consumer does not interrupt training.

        Args:
            event_type: Event name placed in the message.
            data: Event payload. A ``TrainingMetrics`` instance is converted
                to a dict of rounded values; anything else is sent as-is.
        """
        if not self.websocket_callback:
            return
        try:
            message = {
                'event_type': event_type,
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'timestamp': datetime.now().isoformat(),
                'data': data
            }
            if isinstance(data, TrainingMetrics):
                # Avoid dividing by zero (which would drop the whole event)
                # when the run was configured with zero epochs.
                if data.total_epochs > 0:
                    overall_progress = round((data.epoch / data.total_epochs) * 100, 1)
                else:
                    overall_progress = 0
                message['data'] = {
                    'epoch': data.epoch,
                    'total_epochs': data.total_epochs,
                    'batch': data.batch,
                    'total_batches': data.total_batches,
                    'current_loss': round(data.current_loss, 6),
                    'avg_loss': round(data.avg_loss, 6),
                    'learning_rate': data.learning_rate,
                    'epoch_duration': round(data.epoch_duration, 2),
                    'total_duration': round(data.total_duration, 2),
                    'batches_per_second': round(data.batches_per_second, 2),
                    'samples_per_second': round(data.samples_per_second, 0),
                    'eta_current_epoch': round(data.eta_current_epoch, 1),
                    'eta_total': round(data.eta_total, 1),
                    'stage': data.stage,
                    'stage_progress': round(data.stage_progress, 1),
                    'overall_progress': overall_progress
                }
            self.websocket_callback(message)
        except Exception as e:
            # Best-effort delivery: report the failure and carry on.
            print(f"Broadcast failed: {e}")

    def get_current_status(self) -> Dict[str, Any]:
        """Return a dict describing the current training state.

        Returns:
            ``{'status': 'idle'}`` when the run is neither active nor
            completed (this includes cancelled and failed runs); otherwise a
            dict with identifiers, position, stage, overall progress, total
            ETA and the cancelled flag.
        """
        with self._lock:
            if not self.is_training and not self.is_completed:
                return {'status': 'idle'}

            current_loss = self.current_epoch_losses[-1] if self.current_epoch_losses else 0
            metrics = self._calculate_metrics(current_loss, self._last_learning_rate)

            if self.total_epochs > 0:
                overall_progress = (self.current_epoch / self.total_epochs) * 100
            else:
                overall_progress = 0

            return {
                # Past the guard above, "not training" implies "completed".
                'status': 'training' if self.is_training else 'completed',
                'training_id': self.training_id,
                'product_id': self.product_id,
                'model_type': self.model_type,
                'training_mode': self.training_mode,
                'current_epoch': self.current_epoch,
                'total_epochs': self.total_epochs,
                'current_batch': self.current_batch,
                'total_batches': self.total_batches_per_epoch,
                'current_stage': self.current_stage,
                'stage_progress': self.stage_progress,
                'overall_progress': overall_progress,
                'eta_total': metrics.eta_total,
                'is_cancelled': self.is_cancelled
            }
# Global progress-manager instance, created at import time.
# No WebSocket callback is passed here, so progress broadcasts are skipped
# until `websocket_callback` is assigned on this instance.
progress_manager = TrainingProgressManager()