Skip to content

Commit

Permalink
Add pre-activation ResNet
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu committed Jun 6, 2017
1 parent 818df51 commit 9e0454a
Showing 1 changed file with 67 additions and 5 deletions.
72 changes: 67 additions & 5 deletions models/resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,12 @@
'''ResNet18/34/50/101/152 in Pytorch.'''
'''ResNet in PyTorch.

BasicBlock and Bottleneck are from the original ResNet paper:
"Deep Residual Learning for Image Recognition", CVPR 2016.

PreActBlock and PreActBottleneck are from the follow-up paper:
"Identity Mappings in Deep Residual Networks", ECCV 2016,
whose pre-activation structure is considered better than the original.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
Expand Down Expand Up @@ -35,6 +43,31 @@ def forward(self, x):
return out


class PreActBlock(nn.Module):
    '''Pre-activation version of the BasicBlock.

    Follows "Identity Mappings in Deep Residual Networks" (ECCV 2016):
    BN+ReLU precede each convolution, and the projection shortcut (when
    needed) consumes the pre-activated signal rather than the raw input.
    The identity shortcut bypasses the activation entirely, keeping the
    addition path a pure identity mapping.
    '''
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBlock, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = conv3x3(in_planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes)

        # Projection shortcut only when the spatial size or channel count
        # changes.  A bare 1x1 conv — no BatchNorm — so the residual
        # addition path stays clean, as the pre-activation paper prescribes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        # Pre-activate once; both the residual branch and the projection
        # shortcut read the activated signal.  When no projection is
        # needed, the shortcut is the untouched input x.
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out += shortcut
        return out


class Bottleneck(nn.Module):
expansion = 4

Expand Down Expand Up @@ -63,6 +96,34 @@ def forward(self, x):
return out


class PreActBottleneck(nn.Module):
    '''Pre-activation version of the original Bottleneck module.

    Follows "Identity Mappings in Deep Residual Networks" (ECCV 2016):
    BN+ReLU precede each convolution, and the projection shortcut (when
    needed) consumes the pre-activated signal rather than the raw input.
    The identity shortcut bypasses the activation entirely, keeping the
    addition path a pure identity mapping.
    '''
    expansion = 4

    def __init__(self, in_planes, planes, stride=1):
        super(PreActBottleneck, self).__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        # Stride lives on the 3x3 conv, matching the residual branch layout.
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn3 = nn.BatchNorm2d(planes)
        self.conv3 = nn.Conv2d(planes, self.expansion*planes, kernel_size=1, bias=False)

        # Projection shortcut only when the spatial size or channel count
        # changes.  A bare 1x1 conv — no BatchNorm — so the residual
        # addition path stays clean, as the pre-activation paper prescribes.
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False)
            )

    def forward(self, x):
        # Pre-activate once; both the residual branch and the projection
        # shortcut read the activated signal.  When no projection is
        # needed, the shortcut is the untouched input x.
        out = F.relu(self.bn1(x))
        shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        out = self.conv3(F.relu(self.bn3(out)))
        out += shortcut
        return out


class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
Expand Down Expand Up @@ -97,7 +158,7 @@ def forward(self, x):


def ResNet18():
    '''ResNet-18 for CIFAR: four stages of [2, 2, 2, 2] pre-activation
    basic blocks, as introduced by this commit.'''
    # Diff artifact removed: the superseded pre-image line
    # (`return ResNet(BasicBlock, [2,2,2,2])`) preceded this return,
    # which would have made the post-image line unreachable.
    return ResNet(PreActBlock, [2,2,2,2])

def ResNet34():
    '''ResNet-34 for CIFAR: four stages of [3, 4, 6, 3] BasicBlocks.'''
    return ResNet(BasicBlock, [3,4,6,3])
Expand All @@ -111,9 +172,10 @@ def ResNet101():
def ResNet152():
    '''ResNet-152 for CIFAR: four stages of [3, 8, 36, 3] Bottlenecks.'''
    return ResNet(Bottleneck, [3,8,36,3])

def test():
    '''Smoke test: push one CIFAR-sized batch (1x3x32x32) through
    ResNet18 and print the output size.

    Diff artifact removed: the superseded pre-image function
    (`test_resnet` building ResNet50) and its commented-out call
    were interleaved with this post-image version.
    '''
    net = ResNet18()
    # NOTE(review): Variable is the pre-0.4 PyTorch autograd wrapper;
    # presumably imported near the top of the file (outside this view) —
    # confirm, and drop it once the codebase targets torch >= 0.4.
    y = net(Variable(torch.randn(1,3,32,32)))
    print(y.size())

# test()

0 comments on commit 9e0454a

Please sign in to comment.