nin.py
import torch
from torch import nn
from d2l import torch as d2l
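# Note: d2l is the companion package of the "Dive into Deep Learning" book
# (installable via `pip install d2l`); it supplies the Fashion-MNIST loader
# and the training loop used below.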
# NiN's key idea: use 1x1 convolution kernels in place of fully connected layers.
# Each nin_block is followed by a MaxPool2d to reduce the output height and width.
def nin_block(in_channels, out_channels, kernel_size, stride, padding):
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU(),
        nn.Conv2d(out_channels, out_channels, kernel_size=1),
        nn.ReLU()
    )
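# Illustrative sketch (not part of the original script; the helper name is
# made up): a 1x1 convolution applies the same linear map to the channel
# vector at every spatial position, which is exactly why it can stand in for
# a per-pixel fully connected layer.
def _demo_1x1_conv_equals_fc():
    conv = nn.Conv2d(4, 8, kernel_size=1)
    fc = nn.Linear(4, 8)
    # Reuse the conv weights: the 1x1 kernel is the FC weight matrix
    # reshaped to (out_channels, in_channels, 1, 1)
    fc.weight.data = conv.weight.data.reshape(8, 4)
    fc.bias.data = conv.bias.data
    x = torch.randn(2, 4, 6, 6)
    # Apply the FC layer per pixel: move channels last, apply, move back
    y_fc = fc(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
    assert torch.allclose(conv(x), y_fc, atol=1e-6)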
net = nn.Sequential(
    nin_block(1, 96, kernel_size=11, stride=4, padding=0),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(96, 256, kernel_size=5, stride=1, padding=2),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nin_block(256, 384, kernel_size=3, stride=1, padding=1),
    nn.MaxPool2d(kernel_size=3, stride=2),
    nn.Dropout(0.5),
    # There are 10 label classes
    nin_block(384, 10, kernel_size=3, stride=1, padding=1),
    # Global average pooling: the window shape is set automatically to the
    # input's height and width
    nn.AdaptiveAvgPool2d((1, 1)),
    # Flatten turns the four-dimensional output into a two-dimensional one
    # of shape (batch size, 10)
    nn.Flatten())
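# For a 1x1x224x224 input, the shape walk in __main__ should report (derived
# from the layer hyperparameters above): 96x54x54 -> 96x26x26 -> 256x26x26 ->
# 256x12x12 -> 384x12x12 -> 384x5x5 -> 10x5x5 -> 10x1x1 -> (1, 10) after
# Flatten.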
if __name__ == '__main__':
    # Walk a dummy input through the network and print each layer's output shape
    X = torch.randn(size=(1, 1, 224, 224), dtype=torch.float32)
    for layer in net:
        X = layer(X)
        print(layer.__class__.__name__, 'output shape:\t', X.shape)
    # Train on Fashion-MNIST, resized to 224x224 to match the architecture
    batch_size = 128
    train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=224)
    lr, num_epochs = 0.01, 10
    d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())