48 lines
1.2 KiB
Python
48 lines
1.2 KiB
Python
![]() |
import torch
|
|||
|
|
|||
|
def get_device():
|
|||
|
"""
|
|||
|
获取可用的设备(GPU或CPU)
|
|||
|
|
|||
|
返回:
|
|||
|
torch.device: 可用的设备
|
|||
|
"""
|
|||
|
if torch.cuda.is_available():
|
|||
|
return torch.device('cuda')
|
|||
|
else:
|
|||
|
return torch.device('cpu')
|
|||
|
|
|||
|
def to_device(data, device):
|
|||
|
"""
|
|||
|
将数据移动到指定设备
|
|||
|
|
|||
|
参数:
|
|||
|
data: 要移动的数据(可以是张量、列表、元组或字典)
|
|||
|
device: 目标设备
|
|||
|
|
|||
|
返回:
|
|||
|
移动到设备上的数据
|
|||
|
"""
|
|||
|
if isinstance(data, (list, tuple)):
|
|||
|
return [to_device(x, device) for x in data]
|
|||
|
elif isinstance(data, dict):
|
|||
|
return {k: to_device(v, device) for k, v in data.items()}
|
|||
|
elif isinstance(data, torch.Tensor):
|
|||
|
return data.to(device)
|
|||
|
else:
|
|||
|
return data
|
|||
|
|
|||
|
class DeviceDataLoader:
|
|||
|
"""
|
|||
|
将DataLoader的数据移动到指定设备的包装器
|
|||
|
"""
|
|||
|
def __init__(self, dataloader, device):
|
|||
|
self.dataloader = dataloader
|
|||
|
self.device = device
|
|||
|
|
|||
|
def __iter__(self):
|
|||
|
for batch in self.dataloader:
|
|||
|
yield to_device(batch, self.device)
|
|||
|
|
|||
|
def __len__(self):
|
|||
|
return len(self.dataloader)
|