Skip to content

Commit

Permalink
fix seg model.
Browse files Browse the repository at this point in the history
  • Loading branch information
donnyyou committed Jan 6, 2020
1 parent 1e9a86a commit f0eb05b
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 18 deletions.
25 changes: 15 additions & 10 deletions model/seg/nets/deeplabv3.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(self, features, inner_features=512, out_features=512, dilations=(12
def forward(self, x):
_, _, h, w = x.size()

feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=True)
feat1 = F.interpolate(self.conv1(x), size=(h, w), mode='bilinear', align_corners=False)

feat2 = self.conv2(x)
feat3 = self.conv3(x)
Expand All @@ -65,13 +65,17 @@ def __init__(self, configer):
super(DeepLabV3, self).__init__()
self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = ModuleHelper.get_backbone(
base = ModuleHelper.get_backbone(
backbone=self.configer.get('network.backbone'),
pretrained=self.configer.get('network.pretrained')
)

self.head = nn.Sequential(ASPPModule(self.backbone.get_num_features(),
norm_type=self.configer.get('network', 'norm_type')),
self.stage1 = nn.Squential(
base.conv1, base.bn1, base.relu1, base.conv2, base.bn2, base.relu2, base.conv3, base.bn3,
base.relu3, base.max_pool, base.layer1, base.layer2, base.layer3
)
self.stage2 = base.layer4
num_features = 512 if 'resnet18' in self.configer.get('network.backbone') else 2048
self.head = nn.Sequential(ASPPModule(num_features, norm_type=self.configer.get('network', 'norm_type')),
nn.Conv2d(512, self.num_classes, kernel_size=1, stride=1, padding=0, bias=True))
self.dsn = nn.Sequential(
nn.Conv2d(1024, 512, kernel_size=3, stride=1, padding=1),
Expand All @@ -82,13 +86,14 @@ def __init__(self, configer):
self.valid_loss_dict = configer.get('loss', 'loss_weights', configer.get('loss.loss_type'))

def forward(self, data_dict):
x = self.backbone(data_dict['img'])
x_dsn = self.dsn(x[-2])
x = self.head(x[-1])
x = self.stage1(data_dict['img'])
x_dsn = self.dsn(x)
x = self.stage2(x)
x = self.head(x)
x_dsn = F.interpolate(x_dsn, size=(data_dict['img'].size(2), data_dict['img'].size(3)),
mode="bilinear", align_corners=True)
mode="bilinear", align_corners=False)
x = F.interpolate(x, size=(data_dict['img'].size(2), data_dict['img'].size(3)),
mode="bilinear", align_corners=True)
mode="bilinear", align_corners=False)
out_dict = dict(dsn_out=x_dsn, out=x)
if self.configer.get('phase') == 'test':
return out_dict
Expand Down
22 changes: 14 additions & 8 deletions model/seg/nets/pspnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(self, x):
ppm_out = [x]
for pool_scale in self.ppm:
ppm_out.append(F.interpolate(pool_scale(x), (input_size[2], input_size[3]),
mode='bilinear', align_corners=True))
mode='bilinear', align_corners=False))

ppm_out = torch.cat(ppm_out, 1)

Expand All @@ -62,11 +62,16 @@ def __init__(self, configer):
super(PSPNet, self).__init__()
self.configer = configer
self.num_classes = self.configer.get('data', 'num_classes')
self.backbone = ModuleHelper.get_backbone(
base = ModuleHelper.get_backbone(
backbone=self.configer.get('network.backbone'),
pretrained=self.configer.get('network.pretrained')
)
num_features = self.backbone.get_num_features()
self.stage1 = nn.Squential(
base.conv1, base.bn1, base.relu1, base.conv2, base.bn2, base.relu2, base.conv3, base.bn3,
base.relu3, base.max_pool, base.layer1, base.layer2, base.layer3
)
self.stage2 = base.layer4
num_features = 512 if 'resnet18' in self.configer.get('network.backbone') else 2048
self.dsn = nn.Sequential(
_ConvBatchNormReluBlock(num_features // 2, num_features // 4, 3, 1,
norm_type=self.configer.get('network', 'norm_type')),
Expand All @@ -84,14 +89,15 @@ def __init__(self, configer):
self.valid_loss_dict = configer.get('loss', 'loss_weights', configer.get('loss.loss_type'))

def forward(self, data_dict):
x = self.backbone(data_dict['img'])
aux_x = self.dsn(x[-2])
x = self.ppm(x[-1])
x = self.stage1(data_dict['img'])
aux_x = self.dsn(x)
x = self.stage2(x)
x = self.ppm(x)
x = self.cls(x)
x_dsn = F.interpolate(aux_x, size=(data_dict['img'].size(2), data_dict['img'].size(3)),
mode="bilinear", align_corners=True)
mode="bilinear", align_corners=False)
x = F.interpolate(x, size=(data_dict['img'].size(2), data_dict['img'].size(3)),
mode="bilinear", align_corners=True)
mode="bilinear", align_corners=False)
out_dict = dict(dsn_out=x_dsn, out=x)
if self.configer.get('phase') == 'test':
return out_dict
Expand Down

0 comments on commit f0eb05b

Please sign in to comment.