import torch def get_device(): """ 获取可用的设备(GPU或CPU) 返回: torch.device: 可用的设备 """ if torch.cuda.is_available(): return torch.device('cuda') else: return torch.device('cpu') def to_device(data, device): """ 将数据移动到指定设备 参数: data: 要移动的数据(可以是张量、列表、元组或字典) device: 目标设备 返回: 移动到设备上的数据 """ if isinstance(data, (list, tuple)): return [to_device(x, device) for x in data] elif isinstance(data, dict): return {k: to_device(v, device) for k, v in data.items()} elif isinstance(data, torch.Tensor): return data.to(device) else: return data class DeviceDataLoader: """ 将DataLoader的数据移动到指定设备的包装器 """ def __init__(self, dataloader, device): self.dataloader = dataloader self.device = device def __iter__(self): for batch in self.dataloader: yield to_device(batch, self.device) def __len__(self): return len(self.dataloader)