Skip to content

Commit

Permalink
Update RegNet
Browse files Browse the repository at this point in the history
  • Loading branch information
kuangliu committed Apr 15, 2020
1 parent 5b79044 commit 9b0869d
Showing 1 changed file with 9 additions and 9 deletions.
18 changes: 9 additions & 9 deletions models/regnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,16 @@ def forward(self, x):


class Block(nn.Module):
def __init__(self, w_in, w_out, stride, num_groups, bottleneck_ratio, se_ratio):
def __init__(self, w_in, w_out, stride, group_width, bottleneck_ratio, se_ratio):
super(Block, self).__init__()
# 1x1
w_b = int(round(w_out * bottleneck_ratio))
self.conv1 = nn.Conv2d(w_in, w_b, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(w_b)
# 3x3
groups = w_b // num_groups
num_groups = w_b // group_width
self.conv2 = nn.Conv2d(w_b, w_b, kernel_size=3,
stride=stride, padding=1, groups=groups, bias=False)
stride=stride, padding=1, groups=num_groups, bias=False)
self.bn2 = nn.BatchNorm2d(w_b)
# se
self.with_se = se_ratio > 0
Expand Down Expand Up @@ -83,15 +83,15 @@ def _make_layer(self, idx):
depth = self.cfg['depths'][idx]
width = self.cfg['widths'][idx]
stride = self.cfg['strides'][idx]
num_groups = self.cfg['num_groups']
group_width = self.cfg['group_width']
bottleneck_ratio = self.cfg['bottleneck_ratio']
se_ratio = self.cfg['se_ratio']

layers = []
for i in range(depth):
s = stride if i == 0 else 1
layers.append(Block(self.in_planes, width,
s, num_groups, bottleneck_ratio, se_ratio))
s, group_width, bottleneck_ratio, se_ratio))
self.in_planes = width
return nn.Sequential(*layers)

Expand All @@ -112,7 +112,7 @@ def RegNetX_200MF():
'depths': [1, 1, 4, 7],
'widths': [24, 56, 152, 368],
'strides': [1, 1, 2, 2],
'num_groups': 8,
'group_width': 8,
'bottleneck_ratio': 1,
'se_ratio': 0,
}
Expand All @@ -124,7 +124,7 @@ def RegNetX_400MF():
'depths': [1, 2, 7, 12],
'widths': [32, 64, 160, 384],
'strides': [1, 1, 2, 2],
'num_groups': 16,
'group_width': 16,
'bottleneck_ratio': 1,
'se_ratio': 0,
}
Expand All @@ -136,15 +136,15 @@ def RegNetY_400MF():
'depths': [1, 2, 7, 12],
'widths': [32, 64, 160, 384],
'strides': [1, 1, 2, 2],
'num_groups': 16,
'group_width': 16,
'bottleneck_ratio': 1,
'se_ratio': 0.25,
}
return RegNet(cfg)


def test():
net = RegNetY_400MF()
net = RegNetX_200MF()
print(net)
x = torch.randn(2, 3, 32, 32)
y = net(x)
Expand Down

0 comments on commit 9b0869d

Please sign in to comment.