|
| 1 | +# The Image Classification Dataset |
| 2 | +:label:`sec_fashion_mnist` |
| 3 | + |
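| 4 | +One of the most widely used datasets for image classification is the MNIST dataset :cite:`LeCun.Bottou.Bengio.ea.1998`. While it had a good run as a benchmark dataset, even simple models by today's standards achieve classification accuracy above 95%, which makes it unsuitable for distinguishing stronger models from weaker ones. Today, MNIST serves more as a sanity check than as a benchmark. To up the ante just a bit, in the coming sections we focus on the qualitatively similar but comparatively complex Fashion-MNIST dataset :cite:`Xiao.Rasul.Vollgraf.2017`, which was released in 2017. |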
| 5 | + |
| 6 | +```{.python .input} |
| 7 | +%matplotlib inline |
| 8 | +from d2l import mxnet as d2l |
| 9 | +from mxnet import gluon |
| 10 | +import sys |
| 11 | + |
| 12 | +d2l.use_svg_display() |
| 13 | +``` |
| 14 | + |
| 15 | +```{.python .input} |
| 16 | +#@tab pytorch |
| 17 | +%matplotlib inline |
| 18 | +from d2l import torch as d2l |
| 19 | +import torch |
| 20 | +import torchvision |
| 21 | +from torchvision import transforms |
| 22 | +from torch.utils import data |
| 23 | + |
| 24 | +d2l.use_svg_display() |
| 25 | +``` |
| 26 | + |
| 27 | +```{.python .input} |
| 28 | +#@tab tensorflow |
| 29 | +%matplotlib inline |
| 30 | +from d2l import tensorflow as d2l |
| 31 | +import tensorflow as tf |
| 32 | + |
| 33 | +d2l.use_svg_display() |
| 34 | +``` |
| 35 | + |
| 36 | +## Reading the Dataset |
| 37 | + |
| 38 | +We can download the Fashion-MNIST dataset and read it into memory via the built-in functions in the framework. |
| 39 | + |
| 40 | +```{.python .input} |
| 41 | +mnist_train = gluon.data.vision.FashionMNIST(train=True) |
| 42 | +mnist_test = gluon.data.vision.FashionMNIST(train=False) |
| 43 | +``` |
| 44 | + |
| 45 | +```{.python .input} |
| 46 | +#@tab pytorch |
| 47 | +# `ToTensor` converts the image data from PIL type to 32-bit floating point |
| 48 | +# tensors. It divides all numbers by 255 so that all pixel values are between |
| 49 | +# 0 and 1 |
| 50 | +trans = transforms.ToTensor() |
| 51 | +mnist_train = torchvision.datasets.FashionMNIST( |
| 52 | +    root="../data", train=True, transform=trans, download=True) |
| 53 | +mnist_test = torchvision.datasets.FashionMNIST( |
| 54 | +    root="../data", train=False, transform=trans, download=True) |
| 55 | +``` |
| 56 | + |
| 57 | +```{.python .input} |
| 58 | +#@tab tensorflow |
| 59 | +mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data() |
| 60 | +``` |
| 61 | + |
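| 62 | +Fashion-MNIST consists of images from 10 categories, each represented by 6000 images in the training dataset and by 1000 images in the test dataset. A *test dataset* (or *test set*) is used for evaluating model performance and not for training. Consequently, the training set and the test set contain 60000 and 10000 images, respectively. |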
| 63 | + |
| 64 | +```{.python .input} |
| 65 | +#@tab mxnet, pytorch |
| 66 | +len(mnist_train), len(mnist_test) |
| 67 | +``` |
| 68 | + |
| 69 | +```{.python .input} |
| 70 | +#@tab tensorflow |
| 71 | +len(mnist_train[0]), len(mnist_test[0]) |
| 72 | +``` |
| 73 | + |
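| 74 | +The height and width of each input image are both 28 pixels. Note that the dataset consists of grayscale images, whose number of channels is 1. For brevity, throughout this book we write the shape of any image with a height of $h$ pixels and a width of $w$ pixels as $h \times w$ or ($h$, $w$). |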
| 75 | + |
| 76 | +```{.python .input} |
| 77 | +#@tab all |
| 78 | +mnist_train[0][0].shape |
| 79 | +``` |
| 80 | + |
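| 81 | +The images in Fashion-MNIST are associated with the following categories: t-shirt, trousers, pullover, dress, coat, sandal, shirt, sneaker, bag, and ankle boot. The following function converts numeric label indices into their text names. |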
| 82 | + |
| 83 | +```{.python .input} |
| 84 | +#@tab all |
| 85 | +def get_fashion_mnist_labels(labels): #@save |
| 86 | +    """Return text labels for the Fashion-MNIST dataset.""" |
| 87 | +    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat', |
| 88 | +                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot'] |
| 89 | +    return [text_labels[int(i)] for i in labels] |
| 90 | +``` |
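As a quick check of the label lookup, the sketch below repeats the function in standalone form so that it runs without any framework, and applies it to a few label indices. The indices in `example_labels` are hypothetical values chosen only for illustration; they are not read from the dataset.

```python
# Standalone copy of the label lookup so that this snippet runs on its own.
def get_fashion_mnist_labels(labels):
    """Return text labels for the Fashion-MNIST dataset."""
    text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
                   'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
    return [text_labels[int(i)] for i in labels]

# Hypothetical label indices, chosen only for illustration.
example_labels = [0, 9, 3]
example_names = get_fashion_mnist_labels(example_labels)
print(example_names)  # ['t-shirt', 'ankle boot', 'dress']
```

Because each label is passed through `int`, the same lookup should also accept framework scalars that support integer conversion, which is how it is used with batched labels later in this section.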
| 91 | + |
| 92 | +We can now create a function to visualize these examples. |
| 93 | + |
| 94 | +```{.python .input} |
| 95 | +#@tab all |
| 96 | +def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save |
| 97 | +    """Plot a list of images.""" |
| 98 | +    figsize = (num_cols * scale, num_rows * scale) |
| 99 | +    _, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize) |
| 100 | +    axes = axes.flatten() |
| 101 | +    for i, (ax, img) in enumerate(zip(axes, imgs)): |
| 102 | +        ax.imshow(d2l.numpy(img)) |
| 103 | +        ax.axes.get_xaxis().set_visible(False) |
| 104 | +        ax.axes.get_yaxis().set_visible(False) |
| 105 | +        if titles: |
| 106 | +            ax.set_title(titles[i]) |
| 107 | +    return axes |
| 108 | +``` |
| 109 | + |
| 110 | +Here are the images and their corresponding labels (in text) for the first few examples in the training dataset. |
| 111 | + |
| 112 | +```{.python .input} |
| 113 | +X, y = mnist_train[:18] |
| 114 | +show_images(X.squeeze(axis=-1), 2, 9, titles=get_fashion_mnist_labels(y)); |
| 115 | +``` |
| 116 | + |
| 117 | +```{.python .input} |
| 118 | +#@tab pytorch |
| 119 | +X, y = next(iter(data.DataLoader(mnist_train, batch_size=18))) |
| 120 | +show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y)); |
| 121 | +``` |
| 122 | + |
| 123 | +```{.python .input} |
| 124 | +#@tab tensorflow |
| 125 | +X = tf.constant(mnist_train[0][:18]) |
| 126 | +y = tf.constant(mnist_train[1][:18]) |
| 127 | +show_images(X, 2, 9, titles=get_fashion_mnist_labels(y)); |
| 128 | +``` |
| 129 | + |
| 130 | +## Reading a Minibatch |
| 131 | + |
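| 132 | +To make our life easier when reading from the training and test sets, we use the built-in data iterator rather than creating one from scratch. Recall that at each iteration, a data loader reads a minibatch of data of size `batch_size`. We also randomly shuffle the examples for the training data iterator. |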
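To make the idea of a shuffled minibatch iterator concrete, here is a minimal framework-free sketch. The helper `minibatch_indices` and its `seed` parameter are hypothetical names introduced for illustration; the built-in loaders used in this section are far more capable (parallel workers, transforms, and so on), so this only shows the indexing logic under simplified assumptions.

```python
import random

def minibatch_indices(num_examples, batch_size, shuffle=True, seed=None):
    """Yield lists of example indices, one list per minibatch."""
    indices = list(range(num_examples))
    if shuffle:
        # Shuffle individual examples (not whole batches) before slicing.
        random.Random(seed).shuffle(indices)
    for start in range(0, num_examples, batch_size):
        # The last minibatch may be smaller than `batch_size`.
        yield indices[start:start + batch_size]

batches = list(minibatch_indices(num_examples=10, batch_size=4, seed=0))
batch_sizes = [len(b) for b in batches]
print(batch_sizes)  # [4, 4, 2]
```

Every example index appears exactly once per pass, and only the final minibatch is allowed to be short.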
| 133 | + |
| 134 | +```{.python .input} |
| 135 | +batch_size = 256 |
| 136 | + |
| 137 | +def get_dataloader_workers(): #@save |
| 138 | +    """Use 4 processes to read the data except for Windows.""" |
| 139 | +    return 0 if sys.platform.startswith('win') else 4 |
| 140 | + |
| 141 | +# `ToTensor` converts the image data from uint8 to 32-bit floating point. It |
| 142 | +# divides all numbers by 255 so that all pixel values are between 0 and 1 |
| 143 | +transformer = gluon.data.vision.transforms.ToTensor() |
| 144 | +train_iter = gluon.data.DataLoader(mnist_train.transform_first(transformer), |
| 145 | +                                   batch_size, shuffle=True, |
| 146 | +                                   num_workers=get_dataloader_workers()) |
| 147 | +``` |
| 148 | + |
| 149 | +```{.python .input} |
| 150 | +#@tab pytorch |
| 151 | +batch_size = 256 |
| 152 | + |
| 153 | +def get_dataloader_workers(): #@save |
| 154 | +    """Use 4 processes to read the data.""" |
| 155 | +    return 4 |
| 156 | + |
| 157 | +train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True, |
| 158 | +                             num_workers=get_dataloader_workers()) |
| 159 | +``` |
| 160 | + |
| 161 | +```{.python .input} |
| 162 | +#@tab tensorflow |
| 163 | +batch_size = 256 |
| 164 | +train_iter = tf.data.Dataset.from_tensor_slices( |
| 165 | +    mnist_train).shuffle(len(mnist_train[0])).batch(batch_size) |
| 166 | +``` |
| 167 | + |
| 168 | +Let us look at the time it takes to read the training data. |
| 169 | + |
| 170 | +```{.python .input} |
| 171 | +#@tab all |
| 172 | +timer = d2l.Timer() |
| 173 | +for X, y in train_iter: |
| 174 | +    continue |
| 175 | +f'{timer.stop():.2f} sec' |
| 176 | +``` |
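The same measurement can be sketched with only the standard library, which may be handy when comparing settings such as different batch sizes. The helper `time_one_pass` is a hypothetical name, and `fake_iter` is a stand-in for a real data iterator, so the timing it reports says nothing about actual data loading speed.

```python
import time

def time_one_pass(batches):
    """Return (seconds elapsed, number of batches) for one full pass."""
    start = time.perf_counter()
    num_batches = 0
    for _ in batches:
        num_batches += 1
    return time.perf_counter() - start, num_batches

# Stand-in for a real data iterator: 60000 examples in batches of 256.
fake_iter = (range(i, min(i + 256, 60000)) for i in range(0, 60000, 256))
elapsed, num_batches = time_one_pass(fake_iter)
print(f'{num_batches} batches in {elapsed:.4f} sec')
```

With 60000 training examples and a batch size of 256, one pass covers 235 minibatches, the last of which holds the remaining 96 examples.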
| 177 | + |
| 178 | +## Putting All Things Together |
| 179 | + |
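| 180 | +Now we define the `load_data_fashion_mnist` function, which obtains and reads the Fashion-MNIST dataset. It returns the data iterators for both the training set and the validation set (here, the Fashion-MNIST test split). In addition, it accepts an optional argument to resize images to another shape. |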
| 181 | + |
| 182 | +```{.python .input} |
| 183 | +def load_data_fashion_mnist(batch_size, resize=None): #@save |
| 184 | +    """Download the Fashion-MNIST dataset and then load it into memory.""" |
| 185 | +    dataset = gluon.data.vision |
| 186 | +    trans = [dataset.transforms.ToTensor()] |
| 187 | +    if resize: |
| 188 | +        trans.insert(0, dataset.transforms.Resize(resize)) |
| 189 | +    trans = dataset.transforms.Compose(trans) |
| 190 | +    mnist_train = dataset.FashionMNIST(train=True).transform_first(trans) |
| 191 | +    mnist_test = dataset.FashionMNIST(train=False).transform_first(trans) |
| 192 | +    return (gluon.data.DataLoader(mnist_train, batch_size, shuffle=True, |
| 193 | +                                  num_workers=get_dataloader_workers()), |
| 194 | +            gluon.data.DataLoader(mnist_test, batch_size, shuffle=False, |
| 195 | +                                  num_workers=get_dataloader_workers())) |
| 196 | +``` |
| 197 | + |
| 198 | +```{.python .input} |
| 199 | +#@tab pytorch |
| 200 | +def load_data_fashion_mnist(batch_size, resize=None): #@save |
| 201 | +    """Download the Fashion-MNIST dataset and then load it into memory.""" |
| 202 | +    trans = [transforms.ToTensor()] |
| 203 | +    if resize: |
| 204 | +        trans.insert(0, transforms.Resize(resize)) |
| 205 | +    trans = transforms.Compose(trans) |
| 206 | +    mnist_train = torchvision.datasets.FashionMNIST( |
| 207 | +        root="../data", train=True, transform=trans, download=True) |
| 208 | +    mnist_test = torchvision.datasets.FashionMNIST( |
| 209 | +        root="../data", train=False, transform=trans, download=True) |
| 210 | +    return (data.DataLoader(mnist_train, batch_size, shuffle=True, |
| 211 | +                            num_workers=get_dataloader_workers()), |
| 212 | +            data.DataLoader(mnist_test, batch_size, shuffle=False, |
| 213 | +                            num_workers=get_dataloader_workers())) |
| 214 | +``` |
| 215 | + |
| 216 | +```{.python .input} |
| 217 | +#@tab tensorflow |
| 218 | +def load_data_fashion_mnist(batch_size, resize=None): #@save |
| 219 | +    """Download the Fashion-MNIST dataset and then load it into memory.""" |
| 220 | +    mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data() |
| 221 | +    # Divide all numbers by 255 so that all pixel values are between |
| 222 | +    # 0 and 1, add a channel dimension at the end, and cast labels to int32 |
| 223 | +    process = lambda X, y: (tf.expand_dims(X, axis=3) / 255, |
| 224 | +                            tf.cast(y, dtype='int32')) |
| 225 | +    resize_fn = lambda X, y: ( |
| 226 | +        tf.image.resize_with_pad(X, resize, resize) if resize else X, y) |
| 227 | +    return ( |
| 228 | +        tf.data.Dataset.from_tensor_slices(process(*mnist_train)).shuffle( |
| 229 | +            len(mnist_train[0])).batch(batch_size).map(resize_fn), |
| 230 | +        tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch( |
| 231 | +            batch_size).map(resize_fn)) |
| 232 | +``` |
| 233 | + |
| 234 | +Below we test the image resizing feature of the `load_data_fashion_mnist` function by specifying the `resize` argument. |
| 235 | + |
| 236 | +```{.python .input} |
| 237 | +#@tab all |
| 238 | +train_iter, test_iter = load_data_fashion_mnist(32, resize=64) |
| 239 | +for X, y in train_iter: |
| 240 | +    print(X.shape, X.dtype, y.shape, y.dtype) |
| 241 | +    break |
| 242 | +``` |
| 243 | + |
| 244 | +We are now ready to work with the Fashion-MNIST dataset in the sections that follow. |
| 245 | + |
| 246 | +## Summary |
| 247 | + |
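| 248 | +* Fashion-MNIST is an apparel classification dataset consisting of images representing 10 categories. We will use this dataset in subsequent sections and chapters to evaluate various classification algorithms. |
| 249 | +* We write the shape of any image with a height of $h$ pixels and a width of $w$ pixels as $h \times w$ or ($h$, $w$). |
| 250 | +* Data iterators are a key component for efficient performance. Rely on well-implemented data iterators that exploit high-performance computing to avoid slowing down your training loop. |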
| 251 | + |
| 252 | +## Exercises |
| 253 | + |
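| 254 | +1. Does reducing the `batch_size` (for instance, to 1) affect the reading performance? |
| 255 | +1. The data iterator performance is important. Do you think the current implementation is fast enough? Explore various options to improve it. |
| 256 | +1. Check out the framework's online API documentation. Which other datasets are available? |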
| 257 | + |
| 258 | +:begin_tab:`mxnet` |
| 259 | +[Discussions](https://discuss.d2l.ai/t/48) |
| 260 | +:end_tab: |
| 261 | + |
| 262 | +:begin_tab:`pytorch` |
| 263 | +[Discussions](https://discuss.d2l.ai/t/49) |
| 264 | +:end_tab: |
| 265 | + |
| 266 | +:begin_tab:`tensorflow` |
| 267 | +[Discussions](https://discuss.d2l.ai/t/224) |
| 268 | +:end_tab: |