互联网爱好者创业的站长之家 – 南方站长网
您的位置:首页 >资讯 >

在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据

时间:2021-07-16 09:45:36 | 来源:TechWeb

原标题:在 PyTorch 中使用 Datasets 和 DataLoader 自定义数据

有时候,在处理大数据集时,一次将整个数据加载到内存中变得非常难。

因此,唯一的方法是将数据分批加载到内存中进行处理,这需要编写额外的代码来执行此操作。对此,PyTorch 已经提供了 Dataloader 功能。

DataLoader

下面显示了 PyTorch 库中DataLoader函数的语法及其参数信息。

DataLoader(dataset,batch_size=1,shuffle=False,sampler=None,batch_sampler=None,num_workers=0,collate_fn=None,pin_memory=False,drop_last=False,timeout=0,worker_init_fn=None,*,prefetch_factor=2,persistent_workers=False)

几个重要参数

dataset:必须首先使用数据集构造 DataLoader 类。 Shuffle :是否重新整理数据。 Sampler :指的是可选的 torch.utils.data.Sampler 类实例。采样器定义了检索样本的策略,顺序或随机或任何其他方式。使用采样器时应将 Shuffle 设置为 false。 Batch_Sampler :批处理级别。 num_workers :加载数据所需的子进程数。 collate_fn :将样本整理成批次。Torch 中可以进行自定义整理。 加载内置 MNIST 数据集

MNIST 是一个著名的包含手写数字的数据集。下面介绍如何使用DataLoader功能处理 PyTorch 的内置 MNIST 数据集。

上面代码,导入了 torchvision 的torch计算机视觉模块。通常在处理图像数据集时使用,并且可以帮助对图像进行规范化、调整大小和裁剪。

对于 MNIST 数据集,下面使用了归一化技术。

ToTensor()能够把灰度范围从0-255变换到0-1之间。

transform=transforms.Compose([transforms.ToTensor()])

下面代码用于加载所需的数据集。使用 PyTorchDataLoader通过给定 batch_size = 64来加载数据。shuffle=True打乱数据。

trainset=datasets.MNIST('~/.pytorch/MNIST_data/',download=True,train=True,transform=transform)trainloader=torch.utils.data.DataLoader(trainset,batch_size=64,shuffle=True)

为了获取数据集的所有图像,一般使用iter函数和数据加载器DataLoader。

dataiter=iter(trainloader)images,labels=dataiter.next()print(images.shape)print(labels.shape)plt.imshow(images[1].numpy().squeeze(),cmap='Greys_r')

自定义数据集

下面的代码创建一个包含 1000 个随机数的自定义数据集。

fromtorch.utils.dataimportDatasetimportrandomclassSampleDataset(Dataset):def__init__(self,r1,r2):randomlist=[]foriinrange(120):n=random.randint(r1,r2)randomlist.append(n)self.samples=randomlistdef__len__(self):returnlen(self.samples)def__getitem__(self,idx):return(self.samples[idx])dataset=SampleDataset(1,100)dataset[100:120]

在这里插入图片描述在这里插入图片描述

最后,将在自定义数据集上使用 dataloader 函数。将 batch_size 设为 12,并且还启用了num_workers =2 的并行多进程数据加载。

fromtorch.utils.dataimportDataLoaderloader=DataLoader(dataset,batch_size=12,shuffle=True,num_workers=2)fori,batchinenumerate(loader):print(i,batch)

写在后面通过几个示例了解了 PyTorch Dataloader 在将大量数据批量加载到内存中的作用。

郑重声明:本文版权归原作者所有,转载文章仅为传播更多信息之目的,如有侵权行为,请第一时间联系我们修改或删除,多谢。

猜你喜欢