@@ -82,8 +82,8 @@ def __init__(self, block, layers, num_classes=10):
82
82
self .bn = nn .BatchNorm2d (16 )
83
83
self .relu = nn .ReLU (inplace = True )
84
84
self .layer1 = self .make_layer (block , 16 , layers [0 ])
85
- self .layer2 = self .make_layer (block , 32 , layers [0 ], 2 )
86
- self .layer3 = self .make_layer (block , 64 , layers [1 ], 2 )
85
+ self .layer2 = self .make_layer (block , 32 , layers [1 ], 2 )
86
+ self .layer3 = self .make_layer (block , 64 , layers [2 ], 2 )
87
87
self .avg_pool = nn .AvgPool2d (8 )
88
88
self .fc = nn .Linear (64 , num_classes )
89
89
@@ -112,7 +112,7 @@ def forward(self, x):
112
112
out = self .fc (out )
113
113
return out
114
114
115
- model = ResNet (ResidualBlock , [2 , 2 , 2 , 2 ]).to (device )
115
+ model = ResNet (ResidualBlock , [2 , 2 , 2 ]).to (device )
116
116
117
117
118
118
# Loss and optimizer
@@ -166,4 +166,4 @@ def update_lr(optimizer, lr):
166
166
print ('Accuracy of the model on the test images: {} %' .format (100 * correct / total ))
167
167
168
168
# Save the model checkpoint
169
- torch .save (model .state_dict (), 'resnet.ckpt' )
169
+ torch .save (model .state_dict (), 'resnet.ckpt' )
0 commit comments