数据读取
Torchvision 中默认使用的图像加载器是 PIL,因此为了确保 Torchvision 正常运行,我们还需要安装一个 Python 的第三方图像处理库 ——Pillow 库。Pillow 提供了广泛的文件格式支持,强大的图像处理能力,主要包括图像储存、图像显示、格式转换以及基本的图像处理操作等。
我们先介绍 Torchvision 的常用数据集及其读取方法。
PyTorch 为我们提供了一种十分方便的数据读取机制,即使用 Dataset 类与 DataLoader 类的组合,来得到数据迭代器。在训练或预测时,数据迭代器能够输出每一批次所需的数据,并且对数据进行相应的预处理与数据增强操作。
下面我们分别来看下 Dataset 类与 DataLoader 类。
Dataset 类
PyTorch 中的 Dataset 类是一个抽象类,它可以用来表示数据集。我们通过继承 Dataset 类来自定义数据集的格式、大小和其它属性,后面就可以供 DataLoader 类直接使用。
其实这就表示,无论使用自定义的数据集,还是官方为我们封装好的数据集,其本质都是继承了 Dataset 类。而在继承 Dataset 类时,至少需要重写以下几个方法:
- init ():构造函数,可自定义数据读取方法以及进行数据预处理;
- len ():返回数据集大小;
- getitem ():索引数据集中的某一个数据。
下面我们来编写一个简单的例子,看下如何使用 Dataset 类定义一个 Tensor 类型的数据集。
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
# 构造函数
def __init__(self, data_tensor, target_tensor):
self.data_tensor = data_tensor
self.target_tensor = target_tensor
# 返回数据集大小
def __len__(self):
return self.data_tensor.size(0)
# 返回索引的数据与标签
def __getitem__(self, index):
return self.data_tensor[index], self.target_tensor[index]
'''
我们定义了一个名字为 MyDataset 的数据集,在构造函数中,传入 Tensor 类型的数据与标签;
在 __len__ 函数中,直接返回 Tensor 的大小;在 __getitem__ 函数中返回索引的数据与标签。
'''
然后我们来看一下如何调用刚才定义的数据集。首先随机生成一个 10*3 维的数据 Tensor,然后生成 10 维的标签 Tensor,与数据 Tensor 相对应。利用这两个 Tensor,生成一个 MyDataset 的对象。查看数据集的大小可以直接用 len () 函数,索引调用数据可以直接使用下标。
# 生成数据
data_tensor = torch.randn(10, 3)
target_tensor = torch.randint(2, (10,)) # 标签是0或1
# 生成10个随机数,随机数的范围只能是0或者1
# 将数据封装成Dataset
my_dataset = MyDataset(data_tensor, target_tensor)
# 查看数据集大小
print('Dataset size:', len(my_dataset))
# 使用索引调用数据
print('tensor_data[0]: ', my_dataset[0])
DataLoader 类
在实际项目中,如果数据量很大,考虑到内存有限、I/O 速度等问题,在训练过程中不可能一次性的将所有数据全部加载到内存中,也不能只用一个进程去加载,所以就需要多进程、迭代加载,而 DataLoader 就是基于这些需要被设计出来的。
DataLoader 是一个迭代器,最基本的使用方法就是传入一个 Dataset 对象,它会根据参数 batch_size 的值生成一个 batch 的数据,节省内存的同时,它还可以实现多进程、数据打乱等处理。
DataLoader 类的调用方式如下:
from torch.utils.data import DataLoader
tensor_dataloader = DataLoader(dataset=my_dataset, # 传入的数据集, 必须参数
batch_size=2, # 输出的batch大小
shuffle=True, # 数据是否打乱
num_workers=0) # 进程数, 0表示只有主进程
# 以循环形式输出
for data, target in tensor_dataloader:
print(data, target)
'''
输出:
tensor([[-0.1781, -1.1019, -0.1507],
[-0.6170, 0.2366, 0.1006]]) tensor([0, 0])
tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]) tensor([0, 0])
tensor([[-0.4561, -1.2480, -0.3051],
[-0.9738, 0.9465, 0.4812]]) tensor([1, 0])
tensor([[ 0.0260, 1.5276, 0.1687],
[ 1.3692, -0.0170, -1.6831]]) tensor([1, 0])
tensor([[ 0.0515, -0.8892, -0.1699],
[ 0.4931, -0.0697, 0.4171]]) tensor([1, 0])
'''
# 输出一个batch(用iter()强制类型转换成迭代器的对象,next()是输出迭代器下一个元素)
print('One batch tensor data: ', iter(tensor_dataloader).next())
'''
输出:
One batch tensor data: [tensor([[ 0.9451, -0.4923, -1.8178],
[-0.4046, -0.5436, -1.7911]]), tensor([0, 0])]
'''
结合代码,我们梳理一下 DataLoader 中的几个参数,它们分别表示:
- dataset:Dataset 类型,输入的数据集,必须参数;
- batch_size:int 类型,每个 batch 有多少个样本;
- shuffle:bool 类型,在每个 epoch 开始的时候,是否对数据进行重新打乱;
- num_workers:int 类型,加载数据的进程数,0 意味着所有的数据都会被加载进主进程,默认为 0。
思考题
按照上述代码,One batch tensor data 的输出是否正确,若不正确,为什么?
利用 Torchvision 读取数据
Torchvision 库中的 torchvision.datasets 包中提供了丰富的图像数据集的接口。常用的图像数据集,例如 MNIST、COCO 等,这个模块都为我们做了相应的封装。
下表中列出了 torchvision.datasets 包所有支持的数据集。各个数据集的说明与接口,详见链接 https://pytorch.org/vision/stable/datasets.html。
注意,torchvision.datasets 这个包本身并不包含数据集的文件本身,它的工作方式是先从网络上把数据集下载到用户指定目录,然后再用它的加载器把数据集加载到内存中。最后,把这个加载后的数据集作为对象返回给用户。
为了让你进一步加深对知识的理解,我们以 MNIST 数据集为例,来说明一下这个模块具体的使用方法。
MNIST 数据集简介
MNIST 数据集是一个著名的手写数字数据集,因为上手简单,在深度学习领域,手写数字识别是一个很经典的学习入门样例。
MNIST 数据集是 NIST 数据集的一个子集,MNIST 数据集你可以通过这里下载。它包含了四个部分。
MNIST 数据集是 ubyte 格式存储,我们先将 “训练集图片” 解析成图片格式,来直观地看一看数据集具体是什么样子的。具体怎么解析,我们在后面数据预览再展开。
接下来,我们看一下如何使用 Torchvision 来读取 MNIST 数据集。
对于 torchvision.datasets 所支持的所有数据集,它都内置了相应的数据集接口。例如刚才介绍的 MNIST 数据集,torchvision.datasets 就有一个 MNIST 的接口,接口内封装了从下载、解压缩、读取数据、解析数据等全部过程。
这些接口的工作方式差不多,都是先从网络上把数据集下载到指定目录,然后再用加载器把数据集加载到内存中,最后将加载后的数据集作为对象返回给用户。
以 MNIST 为例,我们可以用如下方式调用:
# 以MNIST为例
import torchvision
mnist_dataset = torchvision.datasets.MNIST(root='./data',
train=True,
transform=None,
target_transform=None,
download=True)
torchvision.datasets.MNIST 是一个类,对它进行实例化,即可返回一个 MNIST 数据集对象。构造函数包括包含 5 个参数:
- root:是一个字符串,用于指定你想要保存 MNIST 数据集的位置。如果 download 是 Flase,则会从目标位置读取数据集;
- download:布尔类型,表示是否下载数据集。如果为 True,则会自动从网上下载这个数据集,存储到 root 指定的位置。如果指定位置已经存在数据集文件,则不会重复下载;
- train:布尔类型,表示是否加载训练集数据。如果为 True,则只加载训练数据。如果为 False,则只加载测试数据集。这里需要注意,并不是所有的数据集都做了训练集和测试集的划分,这个参数并不一定是有效参数,具体需要参考官方接口说明文档;
- transform:用于对图像进行预处理操作,例如数据增强、归一化、旋转或缩放等。这些操作我们会在下节课展开讲解;
- target_transform:用于对图像标签进行预处理操作。
运行上述的代码后,程序会首先去指定的网址下载了 MNIST 数据集,然后进行了解压缩等操作。如果你再次运行相同的代码,则不会再有下载的过程。
如果你用 type 函数查看一下 mnist_dataset 的类型,就可以得到 torchvision.datasets.mnist.MNIST ,而这个类是之前我们介绍过的 Dataset 类的派生类。相当于 torchvision.datasets ,它已经帮我们写好了对 Dataset 类的继承,完成了对数据集的封装,我们直接使用即可。
这里我们主要以 MNIST 为例,进行了说明。其它的数据集使用方法类似,调用的时候你只要需要将类名 “MNIST” 换成其它数据集名字即可。
数据预览
完成了数据读取工作,我们得到的是对应的 mnist_dataset,刚才已经讲过了,这是一个封装了的数据集。
如果想要查看 mnist_dataset 中的具体内容,我们需要把它转化为列表。(如果 IOPub data rate 超限,可以只加载测试集数据,令 train=False)
mnist_dataset_list = list(mnist_dataset)
print(mnist_dataset_list)
转换后的数据集对象变成了一个元组列表,每个元组有两个元素,第一个元素是图像数据,第二个元素是图像的标签。
这里图像数据是 PIL.Image.Image 类型的,这种类型可以直接在 Jupyter 中显示出来。显示一条数据的代码如下。
display(mnist_dataset_list[0][0])
print("Image label is:", mnist_dataset_list[0][1])
上面介绍了两种读取数据的方法,也就是自定义和读取常用图像数据集。最通用的数据读取方法,就是自己定义一个 Dataset 的派生类。而读取常用的图像数据集,就可以利用 PyTorch 提供的视觉包 Torchvision。
极客时间版权所有: https://time.geekbang.org/column/article/429826
(有删改)