forked from kuangliu/pytorch-cifar
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
kuangliu
committed
Dec 11, 2017
1 parent
886af4c
commit 0d4ac1c
Showing
2 changed files
with
129 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
'''PNASNet in PyTorch. | ||
Paper: Progressive Neural Architecture Search | ||
''' | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
from torch.autograd import Variable | ||
|
||
|
||
class SepConv(nn.Module): | ||
'''Separable Convolution.''' | ||
def __init__(self, in_planes, out_planes, kernel_size, stride): | ||
super(SepConv, self).__init__() | ||
self.conv1 = nn.Conv2d(in_planes, out_planes, | ||
kernel_size, stride, | ||
padding=(kernel_size-1)//2, | ||
bias=False, groups=in_planes) | ||
self.bn1 = nn.BatchNorm2d(out_planes) | ||
|
||
def forward(self, x): | ||
return self.bn1(self.conv1(x)) | ||
|
||
|
||
class CellA(nn.Module): | ||
def __init__(self, in_planes, out_planes, stride=1): | ||
super(CellA, self).__init__() | ||
self.stride = stride | ||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) | ||
if stride==2: | ||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) | ||
self.bn1 = nn.BatchNorm2d(out_planes) | ||
|
||
def forward(self, x): | ||
y1 = self.sep_conv1(x) | ||
y2 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) | ||
if self.stride==2: | ||
y2 = self.bn1(self.conv1(y2)) | ||
return F.relu(y1+y2) | ||
|
||
class CellB(nn.Module): | ||
def __init__(self, in_planes, out_planes, stride=1): | ||
super(CellB, self).__init__() | ||
self.stride = stride | ||
# Left branch | ||
self.sep_conv1 = SepConv(in_planes, out_planes, kernel_size=7, stride=stride) | ||
self.sep_conv2 = SepConv(in_planes, out_planes, kernel_size=3, stride=stride) | ||
# Right branch | ||
self.sep_conv3 = SepConv(in_planes, out_planes, kernel_size=5, stride=stride) | ||
if stride==2: | ||
self.conv1 = nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) | ||
self.bn1 = nn.BatchNorm2d(out_planes) | ||
# Reduce channels | ||
self.conv2 = nn.Conv2d(2*out_planes, out_planes, kernel_size=1, stride=1, padding=0, bias=False) | ||
self.bn2 = nn.BatchNorm2d(out_planes) | ||
|
||
def forward(self, x): | ||
# Left branch | ||
y1 = self.sep_conv1(x) | ||
y2 = self.sep_conv2(x) | ||
# Right branch | ||
y3 = F.max_pool2d(x, kernel_size=3, stride=self.stride, padding=1) | ||
if self.stride==2: | ||
y3 = self.bn1(self.conv1(y3)) | ||
y4 = self.sep_conv3(x) | ||
# Concat & reduce channels | ||
b1 = F.relu(y1+y2) | ||
b2 = F.relu(y3+y4) | ||
y = torch.cat([b1,b2], 1) | ||
return F.relu(self.bn2(self.conv2(y))) | ||
|
||
class PNASNet(nn.Module): | ||
def __init__(self, cell_type, num_cells, num_planes): | ||
super(PNASNet, self).__init__() | ||
self.in_planes = num_planes | ||
self.cell_type = cell_type | ||
|
||
self.conv1 = nn.Conv2d(3, num_planes, kernel_size=3, stride=1, padding=1, bias=False) | ||
self.bn1 = nn.BatchNorm2d(num_planes) | ||
|
||
self.layer1 = self._make_layer(num_planes, num_cells=6) | ||
self.layer2 = self._downsample(num_planes*2) | ||
self.layer3 = self._make_layer(num_planes*2, num_cells=6) | ||
self.layer4 = self._downsample(num_planes*4) | ||
self.layer5 = self._make_layer(num_planes*4, num_cells=6) | ||
|
||
self.linear = nn.Linear(num_planes*4, 10) | ||
|
||
def _make_layer(self, planes, num_cells): | ||
layers = [] | ||
for _ in range(num_cells): | ||
layers.append(self.cell_type(self.in_planes, planes, stride=1)) | ||
self.in_planes = planes | ||
return nn.Sequential(*layers) | ||
|
||
def _downsample(self, planes): | ||
layer = self.cell_type(self.in_planes, planes, stride=2) | ||
self.in_planes = planes | ||
return layer | ||
|
||
def forward(self, x): | ||
out = F.relu(self.bn1(self.conv1(x))) | ||
out = self.layer1(out) | ||
out = self.layer2(out) | ||
out = self.layer3(out) | ||
out = self.layer4(out) | ||
out = self.layer5(out) | ||
out = F.avg_pool2d(out, 8) | ||
out = self.linear(out.view(out.size(0), -1)) | ||
return out | ||
|
||
|
||
def PNASNetA(): | ||
return PNASNet(CellA, num_cells=6, num_planes=44) | ||
|
||
def PNASNetB(): | ||
return PNASNet(CellB, num_cells=6, num_planes=32) | ||
|
||
|
||
def test(): | ||
net = PNASNetB() | ||
print(net) | ||
x = Variable(torch.randn(1,3,32,32)) | ||
y = net(x) | ||
print(y) | ||
|
||
# test() |