ShopTRAINING/models/utils.py

48 lines
1.2 KiB
Python
Raw Normal View History

import torch
def get_device():
"""
获取可用的设备GPU或CPU
返回:
torch.device: 可用的设备
"""
if torch.cuda.is_available():
return torch.device('cuda')
else:
return torch.device('cpu')
def to_device(data, device):
"""
将数据移动到指定设备
参数:
data: 要移动的数据可以是张量列表元组或字典
device: 目标设备
返回:
移动到设备上的数据
"""
if isinstance(data, (list, tuple)):
return [to_device(x, device) for x in data]
elif isinstance(data, dict):
return {k: to_device(v, device) for k, v in data.items()}
elif isinstance(data, torch.Tensor):
return data.to(device)
else:
return data
class DeviceDataLoader:
"""
将DataLoader的数据移动到指定设备的包装器
"""
def __init__(self, dataloader, device):
self.dataloader = dataloader
self.device = device
def __iter__(self):
for batch in self.dataloader:
yield to_device(batch, self.device)
def __len__(self):
return len(self.dataloader)