Skip to content

Commit 7aa9822

Browse files
Mu Liastonzhang
Mu Li
authored andcommitted
linear
1 parent a21f038 commit 7aa9822

16 files changed

+4912
-2
lines changed

Jenkinsfile

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ stage("Build and Publish") {
5757

5858
sh label:"Build PDF", script:"""set -ex
5959
conda activate ${ENV_NAME}
60-
d2lbook build pdf
60+
# d2lbook build pdf
6161
"""
6262

6363
if (env.BRANCH_NAME == 'release') {
@@ -70,7 +70,7 @@ stage("Build and Publish") {
7070
} else {
7171
sh label:"Publish", script:"""set -ex
7272
conda activate ${ENV_NAME}
73-
d2lbook deploy html pdf --s3 s3://preview.d2l.ai/${JOB_NAME}/
73+
d2lbook deploy html --s3 s3://preview.d2l.ai/${JOB_NAME}/
7474
"""
7575
if (env.BRANCH_NAME.startsWith("PR-")) {
7676
pullRequest.comment("Job ${JOB_NAME}/${BUILD_NUMBER} is complete. \nCheck the results at http://preview.d2l.ai/${JOB_NAME}/")
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,268 @@
1+
# 影像分类数据集
2+
:label:`sec_fashion_mnist`
3+
4+
广泛使用的图像分类数据集之一是 MNIST 数据集 :cite:`LeCun.Bottou.Bengio.ea.1998`。虽然它作为基准数据集运行良好,但即使是按照当今标准的简单模型也能达到 95% 以上的分类准确率,因此不适合区分较强的模型和较弱的模型。如今,MNIST 的作用更多的是理智检查,而不是作为基准。到了赌注只是一点点, 我们将集中讨论在未来部分的质量相似, 但相对复杂的时尚多国主义数据集 :cite:`Xiao.Rasul.Vollgraf.2017`, 这是在 2017 年发布.
5+
6+
```{.python .input}
7+
%matplotlib inline
8+
from d2l import mxnet as d2l
9+
from mxnet import gluon
10+
import sys
11+
12+
d2l.use_svg_display()
13+
```
14+
15+
```{.python .input}
16+
#@tab pytorch
17+
%matplotlib inline
18+
from d2l import torch as d2l
19+
import torch
20+
import torchvision
21+
from torchvision import transforms
22+
from torch.utils import data
23+
24+
d2l.use_svg_display()
25+
```
26+
27+
```{.python .input}
28+
#@tab tensorflow
29+
%matplotlib inline
30+
from d2l import tensorflow as d2l
31+
import tensorflow as tf
32+
33+
d2l.use_svg_display()
34+
```
35+
36+
## 读取数据集
37+
38+
我们可以通过框架中的内置函数将 Fashion-MNist 数据集下载并读取到内存中。
39+
40+
```{.python .input}
41+
mnist_train = gluon.data.vision.FashionMNIST(train=True)
42+
mnist_test = gluon.data.vision.FashionMNIST(train=False)
43+
```
44+
45+
```{.python .input}
46+
#@tab pytorch
47+
# `ToTensor` converts the image data from PIL type to 32-bit floating point
48+
# tensors. It divides all numbers by 255 so that all pixel values are between
49+
# 0 and 1
50+
trans = transforms.ToTensor()
51+
mnist_train = torchvision.datasets.FashionMNIST(
52+
root="../data", train=True, transform=trans, download=True)
53+
mnist_test = torchvision.datasets.FashionMNIST(
54+
root="../data", train=False, transform=trans, download=True)
55+
```
56+
57+
```{.python .input}
58+
#@tab tensorflow
59+
mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
60+
```
61+
62+
时尚 MNist 由 10 个类别的图像组成,每个类别由训练数据集中的 6000 图像和测试数据集中的 1000 个图像表示。* 测试数据集 *(或 * 测试集 *)用于评估模型性能,而不用于培训。因此,训练集和测试集分别包含 60000 和 10000 个图像。
63+
64+
```{.python .input}
65+
#@tab mxnet, pytorch
66+
len(mnist_train), len(mnist_test)
67+
```
68+
69+
```{.python .input}
70+
#@tab tensorflow
71+
len(mnist_train[0]), len(mnist_test[0])
72+
```
73+
74+
每个输入图像的高度和宽度均为 28 像素。请注意,数据集由灰度图像组成,其通道数为 1。为了简洁起见,在这本书中,我们存储任何图像的形状与高度 $h$ 宽度 $w$ 像素为 $h \times w$ 或($h$,$w$)。
75+
76+
```{.python .input}
77+
#@tab all
78+
mnist_train[0][0].shape
79+
```
80+
81+
时尚 MNist 中的图片与以下类别相关:T 恤、长裤、套头衫、连衣裙、外套、凉鞋、衬衫、运动鞋、包和踝靴。以下函数在数字标签索引及其文本名称之间进行转换。
82+
83+
```{.python .input}
84+
#@tab all
85+
def get_fashion_mnist_labels(labels): #@save
86+
"""Return text labels for the Fashion-MNIST dataset."""
87+
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
88+
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
89+
return [text_labels[int(i)] for i in labels]
90+
```
91+
92+
我们现在可以创建一个函数来显示这些示例。
93+
94+
```{.python .input}
95+
#@tab all
96+
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
97+
"""Plot a list of images."""
98+
figsize = (num_cols * scale, num_rows * scale)
99+
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
100+
axes = axes.flatten()
101+
for i, (ax, img) in enumerate(zip(axes, imgs)):
102+
ax.imshow(d2l.numpy(img))
103+
ax.axes.get_xaxis().set_visible(False)
104+
ax.axes.get_yaxis().set_visible(False)
105+
if titles:
106+
ax.set_title(titles[i])
107+
return axes
108+
```
109+
110+
以下是训练数据集中前几个示例的图像及其相应标签(以文本形式)。
111+
112+
```{.python .input}
113+
X, y = mnist_train[:18]
114+
show_images(X.squeeze(axis=-1), 2, 9, titles=get_fashion_mnist_labels(y));
115+
```
116+
117+
```{.python .input}
118+
#@tab pytorch
119+
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
120+
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y));
121+
```
122+
123+
```{.python .input}
124+
#@tab tensorflow
125+
X = tf.constant(mnist_train[0][:18])
126+
y = tf.constant(mnist_train[1][:18])
127+
show_images(X, 2, 9, titles=get_fashion_mnist_labels(y));
128+
```
129+
130+
## 读取小批处理
131+
132+
为了让我们在阅读训练和测试集时更轻松,我们使用内置的数据迭代器,而不是从头开始创建数据迭代器。回想一下,在每次迭代中,数据加载器每次读取大小为 `batch_size` 的小批数据。我们还随机洗牌训练数据迭代器的示例。
133+
134+
```{.python .input}
135+
batch_size = 256
136+
137+
def get_dataloader_workers(): #@save
138+
"""Use 4 processes to read the data expect for Windows."""
139+
return 0 if sys.platform.startswith('win') else 4
140+
141+
# `ToTensor` converts the image data from uint8 to 32-bit floating point. It
142+
# divides all numbers by 255 so that all pixel values are between 0 and 1
143+
transformer = gluon.data.vision.transforms.ToTensor()
144+
train_iter = gluon.data.DataLoader(mnist_train.transform_first(transformer),
145+
batch_size, shuffle=True,
146+
num_workers=get_dataloader_workers())
147+
```
148+
149+
```{.python .input}
150+
#@tab pytorch
151+
batch_size = 256
152+
153+
def get_dataloader_workers(): #@save
154+
"""Use 4 processes to read the data."""
155+
return 4
156+
157+
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
158+
num_workers=get_dataloader_workers())
159+
```
160+
161+
```{.python .input}
162+
#@tab tensorflow
163+
batch_size = 256
164+
train_iter = tf.data.Dataset.from_tensor_slices(
165+
mnist_train).batch(batch_size).shuffle(len(mnist_train[0]))
166+
```
167+
168+
让我们看一下读取训练数据所需的时间。
169+
170+
```{.python .input}
171+
#@tab all
172+
timer = d2l.Timer()
173+
for X, y in train_iter:
174+
continue
175+
f'{timer.stop():.2f} sec'
176+
```
177+
178+
## 把所有东西放在一起
179+
180+
现在我们定义了 `load_data_fashion_mnist` 函数,用于获取和读取时尚多国主义数据集。它返回训练集和验证集的数据迭代器。此外,它还接受一个可选参数,将图像大小调整为另一种形状。
181+
182+
```{.python .input}
183+
def load_data_fashion_mnist(batch_size, resize=None): #@save
184+
"""Download the Fashion-MNIST dataset and then load it into memory."""
185+
dataset = gluon.data.vision
186+
trans = [dataset.transforms.ToTensor()]
187+
if resize:
188+
trans.insert(0, dataset.transforms.Resize(resize))
189+
trans = dataset.transforms.Compose(trans)
190+
mnist_train = dataset.FashionMNIST(train=True).transform_first(trans)
191+
mnist_test = dataset.FashionMNIST(train=False).transform_first(trans)
192+
return (gluon.data.DataLoader(mnist_train, batch_size, shuffle=True,
193+
num_workers=get_dataloader_workers()),
194+
gluon.data.DataLoader(mnist_test, batch_size, shuffle=False,
195+
num_workers=get_dataloader_workers()))
196+
```
197+
198+
```{.python .input}
199+
#@tab pytorch
200+
def load_data_fashion_mnist(batch_size, resize=None): #@save
201+
"""Download the Fashion-MNIST dataset and then load it into memory."""
202+
trans = [transforms.ToTensor()]
203+
if resize:
204+
trans.insert(0, transforms.Resize(resize))
205+
trans = transforms.Compose(trans)
206+
mnist_train = torchvision.datasets.FashionMNIST(
207+
root="../data", train=True, transform=trans, download=True)
208+
mnist_test = torchvision.datasets.FashionMNIST(
209+
root="../data", train=False, transform=trans, download=True)
210+
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
211+
num_workers=get_dataloader_workers()),
212+
data.DataLoader(mnist_test, batch_size, shuffle=False,
213+
num_workers=get_dataloader_workers()))
214+
```
215+
216+
```{.python .input}
217+
#@tab tensorflow
218+
def load_data_fashion_mnist(batch_size, resize=None): #@save
219+
"""Download the Fashion-MNIST dataset and then load it into memory."""
220+
mnist_train, mnist_test = tf.keras.datasets.fashion_mnist.load_data()
221+
# Divide all numbers by 255 so that all pixel values are between
222+
# 0 and 1, add a batch dimension at the last. And cast label to int32
223+
process = lambda X, y: (tf.expand_dims(X, axis=3) / 255,
224+
tf.cast(y, dtype='int32'))
225+
resize_fn = lambda X, y: (
226+
tf.image.resize_with_pad(X, resize, resize) if resize else X, y)
227+
return (
228+
tf.data.Dataset.from_tensor_slices(process(*mnist_train)).batch(
229+
batch_size).shuffle(len(mnist_train[0])).map(resize_fn),
230+
tf.data.Dataset.from_tensor_slices(process(*mnist_test)).batch(
231+
batch_size).map(resize_fn))
232+
```
233+
234+
下面我们通过指定 `resize` 参数来测试 `load_data_fashion_mnist` 函数的图像调整大小特征。
235+
236+
```{.python .input}
237+
#@tab all
238+
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
239+
for X, y in train_iter:
240+
print(X.shape, X.dtype, y.shape, y.dtype)
241+
break
242+
```
243+
244+
我们现在已经准备好与时尚 MNist 数据集在下面的部分。
245+
246+
## 摘要
247+
248+
* 时尚 MNist 是一个服装分类数据集,由代表 10 个类别的图像组成。我们将在后续章节和章节中使用此数据集来评估各种分类算法。
249+
* 我们将任何高度为 $h$ 像素的图像形状存储为 $w$ 像素。
250+
* 数据迭代器是高效性能的关键组件。依靠实施良好的数据迭代器,利用高性能计算来避免减慢训练循环。
251+
252+
## 练习
253+
254+
1.`batch_size`(实例,减少到 1)是否会影响读取性能?
255+
1. 数据迭代器的性能非常重要。你认为当前的实现足够快吗?探索各种选择来改进它。
256+
1. 查看框架的在线 API 文档。还有哪些其他数据集可用?
257+
258+
:begin_tab:`mxnet`
259+
[Discussions](https://discuss.d2l.ai/t/48)
260+
:end_tab:
261+
262+
:begin_tab:`pytorch`
263+
[Discussions](https://discuss.d2l.ai/t/49)
264+
:end_tab:
265+
266+
:begin_tab:`tensorflow`
267+
[Discussions](https://discuss.d2l.ai/t/224)
268+
:end_tab:

0 commit comments

Comments
 (0)