48 lines
1.2 KiB
Python
48 lines
1.2 KiB
Python
import torch
|
||
|
||
def get_device():
|
||
"""
|
||
获取可用的设备(GPU或CPU)
|
||
|
||
返回:
|
||
torch.device: 可用的设备
|
||
"""
|
||
if torch.cuda.is_available():
|
||
return torch.device('cuda')
|
||
else:
|
||
return torch.device('cpu')
|
||
|
||
def to_device(data, device):
|
||
"""
|
||
将数据移动到指定设备
|
||
|
||
参数:
|
||
data: 要移动的数据(可以是张量、列表、元组或字典)
|
||
device: 目标设备
|
||
|
||
返回:
|
||
移动到设备上的数据
|
||
"""
|
||
if isinstance(data, (list, tuple)):
|
||
return [to_device(x, device) for x in data]
|
||
elif isinstance(data, dict):
|
||
return {k: to_device(v, device) for k, v in data.items()}
|
||
elif isinstance(data, torch.Tensor):
|
||
return data.to(device)
|
||
else:
|
||
return data
|
||
|
||
class DeviceDataLoader:
|
||
"""
|
||
将DataLoader的数据移动到指定设备的包装器
|
||
"""
|
||
def __init__(self, dataloader, device):
|
||
self.dataloader = dataloader
|
||
self.device = device
|
||
|
||
def __iter__(self):
|
||
for batch in self.dataloader:
|
||
yield to_device(batch, self.device)
|
||
|
||
def __len__(self):
|
||
return len(self.dataloader) |