数据预处理(torchvision)
不管我们的网络设计的有多复杂,选择什么样的优化器和损失函数,我们在训练模型时首先需要面对的是如何处理我们的数据。最简单的一个问题,我们需要怎么把数据拿过来送进我们的网络中呢 (数据读取)。还有,我们把数据送进去之前还需要对其进行一些什么操作呢。
PyTorch 为我们提供了丰富的 API 以供我们方便的进行学习。
torchvision 是 pytorch 的一个图形库,其中还提供一些常用的数据集和几个已经搭建好的经典网络模型,以及一些图像数据处理方面的工具,主要供数据预处理阶段使用。它服务于 PyTorch 深度学习框架,主要用来构建计算机视觉模型。以下是 torchvision 的构成:
torchvision.datasets: 一些加载数据的函数及常用的数据集接口;
torchvision.models: 包含常用的模型结构(含预训练模型),例如 AlexNet、VGG、ResNet 等;
torchvision.transforms: 常用的图片变换,例如裁剪、旋转等;
torchvision.utils: 其他的一些有用的方法。
这部分是讲解如何读取数据集和处理数据集,如果你不是想真正上手写深度学习的代码,可以无视这块。