48 lines
1.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import torch
def get_device():
    """Pick the device tensors should be placed on.

    Returns:
        torch.device: ``cuda`` when ``torch.cuda.is_available()`` reports a
        usable GPU, otherwise ``cpu``.
    """
    device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
    return torch.device(device_name)
def to_device(data, device, *, non_blocking=False):
    """Recursively move every tensor inside ``data`` to ``device``.

    Containers are traversed and rebuilt with the moved contents:

    * lists come back as lists;
    * tuples come back as tuples, and named tuples are reconstructed with
      their own type so attribute access (``batch.x``) keeps working;
    * dicts come back as dicts with the same keys.

    ``torch.Tensor`` leaves are moved with ``Tensor.to``; any other value
    (numbers, strings, ``None``, ...) is returned unchanged.

    Args:
        data: A tensor, or an arbitrarily nested list/tuple/dict containing
            tensors and other values.
        device: Target device -- a ``torch.device`` or anything
            ``Tensor.to`` accepts (e.g. ``'cpu'``, ``'cuda'``).
        non_blocking: Forwarded to ``Tensor.to``. Defaults to ``False``,
            which is the same call the original implementation made.

    Returns:
        ``data`` with every tensor moved to ``device``.
    """
    if isinstance(data, torch.Tensor):
        return data.to(device, non_blocking=non_blocking)
    if isinstance(data, dict):
        return {
            k: to_device(v, device, non_blocking=non_blocking)
            for k, v in data.items()
        }
    if isinstance(data, (list, tuple)):
        moved = [to_device(x, device, non_blocking=non_blocking) for x in data]
        if isinstance(data, tuple):
            if hasattr(data, '_fields'):
                # namedtuple: its constructor takes the fields positionally,
                # so a plain tuple(...) / list would lose the field names.
                return type(data)(*moved)
            return tuple(moved)
        return moved
    return data
class DeviceDataLoader:
    """Thin wrapper that moves each batch of a data loader onto a device.

    Iterating the wrapper iterates the wrapped loader and passes every batch
    through ``to_device`` before yielding it. ``len()`` is delegated to the
    wrapped loader.
    """

    def __init__(self, dataloader, device):
        # Only references are stored; nothing is iterated until __iter__ runs.
        self.dataloader = dataloader
        self.device = device

    def __iter__(self):
        # Still a generator function: the wrapped loader is not iterated and
        # self.device is not read until the first next() call, and the device
        # is re-read for every batch.
        yield from (to_device(batch, self.device) for batch in self.dataloader)

    def __len__(self):
        return len(self.dataloader)