-
Notifications
You must be signed in to change notification settings - Fork 18
/
Copy pathDFF_edge.py
114 lines (90 loc) · 4.56 KB
/
DFF_edge.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
###########################################################################
# Created by: Yuan Hu
# Email: [email protected]
# Copyright (c) 2019
###########################################################################
from __future__ import division
import os
import numpy as np
import torch
import torch.nn as nn
from modeling.base import BaseNet
class DFF(BaseNet):
    r"""Dynamic Feature Fusion for Semantic Edge Detection.

    Four side branches are attached to the backbone features returned by
    ``base_forward``: three single-channel branches (``side1``..``side3``)
    and one ``nclass``-channel branch (``side5``).  For every class, the
    class's ``side5`` map is grouped with the three shared single-channel
    maps, and the group is reduced with location-dependent weights produced
    by :class:`LocationAdaptiveLearner` from a second branch (``side5_w``).

    Parameters
    ----------
    nclass : int
        Number of categories for the training dataset.
    backbone : string
        Backbone network type, forwarded to ``BaseNet`` (e.g. 'resnet50',
        'resnet101' or 'resnet152').  Required; there is no default.
    norm_layer : object
        Normalization layer constructor used in the backbone and the side
        branches (default: :class:`torch.nn.BatchNorm2d`).
    **kwargs
        Forwarded unchanged to ``BaseNet.__init__``.

    Reference:
        Yuan Hu, Yunpeng Chen, Xiang Li, Jiashi Feng. "Dynamic Feature Fusion
        for Semantic Edge Detection" *IJCAI*, 2019
    """

    def __init__(self, nclass, backbone, norm_layer=nn.BatchNorm2d, **kwargs):
        super(DFF, self).__init__(nclass, backbone, norm_layer=norm_layer, **kwargs)
        self.nclass = nclass

        # Predicts 4 fusion weights per class and per location.
        self.ada_learner = LocationAdaptiveLearner(nclass, nclass * 4, nclass * 4, norm_layer=norm_layer)

        # Side branches: 1x1 conv -> norm, then a transposed conv whose
        # kernel/stride/padding give an exact 2x / 4x / 8x upsampling.
        self.side1 = nn.Sequential(nn.Conv2d(64, 1, 1),
                                   norm_layer(1))
        self.side2 = nn.Sequential(nn.Conv2d(256, 1, 1, bias=True),
                                   norm_layer(1),
                                   nn.ConvTranspose2d(1, 1, 4, stride=2, padding=1, bias=False))
        self.side3 = nn.Sequential(nn.Conv2d(512, 1, 1, bias=True),
                                   norm_layer(1),
                                   nn.ConvTranspose2d(1, 1, 8, stride=4, padding=2, bias=False))
        self.side5 = nn.Sequential(nn.Conv2d(2048, nclass, 1, bias=True),
                                   norm_layer(nclass),
                                   nn.ConvTranspose2d(nclass, nclass, 16, stride=8, padding=4, bias=False))
        self.side5_w = nn.Sequential(nn.Conv2d(2048, nclass * 4, 1, bias=True),
                                     norm_layer(nclass * 4),
                                     nn.ConvTranspose2d(nclass * 4, nclass * 4, 16, stride=8, padding=4, bias=False))

    def forward(self, x):
        """Run the network on ``x``.

        Returns
        -------
        tuple
            ``(side5, fuse)``: the per-class side output and the adaptively
            fused output, both with ``nclass`` channels.  No sigmoid is
            applied here; callers receive the raw branch outputs.
        """
        # The fourth backbone stage is not used by any side branch.
        c1, c2, c3, _, c5 = self.base_forward(x)

        # All side outputs must share the same spatial size (H, W) for the
        # concatenations below to succeed.
        side1 = self.side1(c1)  # (N, 1, H, W)
        side2 = self.side2(c2)  # (N, 1, H, W)
        side3 = self.side3(c3)  # (N, 1, H, W)
        side5 = self.side5(c5)  # (N, nclass, H, W)
        side5_w = self.side5_w(c5)  # (N, nclass*4, H, W)

        ada_weights = self.ada_learner(side5_w)  # (N, nclass, 4, H, W)

        # Build the per-class groups [side5_c, side1, side2, side3] in a
        # single concatenation.  The shared maps are broadcast across the
        # class axis with expand(), which avoids re-copying a growing tensor
        # once per class.
        shared = torch.cat((side1, side2, side3), 1)  # (N, 3, H, W)
        shared = shared.unsqueeze(1).expand(-1, side5.size(1), -1, -1, -1)  # (N, nclass, 3, H, W)
        fuse = torch.cat((side5.unsqueeze(2), shared), 2)  # (N, nclass, 4, H, W)

        # Weighted sum over the 4 maps of each class.
        fuse = torch.mul(fuse, ada_weights)  # (N, nclass, 4, H, W)
        fuse = torch.sum(fuse, 2)  # (N, nclass, H, W)

        return (side5, fuse)
class LocationAdaptiveLearner(nn.Module):
    """Three stacked 1x1-conv + norm stages that emit per-location weights.

    The first two stages end with an in-place ReLU; the last stage has no
    activation.  The result is reshaped so that the channel axis is split
    into ``nclass`` groups.

    Parameters
    ----------
    nclass : int
        Number of groups the output channels are split into.
    in_channels : int
        Channels of the input feature map.
    out_channels : int
        Channels produced by every stage.
    norm_layer : object
        Normalization layer constructor, called with the channel count
        (default: :class:`torch.nn.BatchNorm2d`).
    """

    def __init__(self, nclass, in_channels, out_channels, norm_layer=nn.BatchNorm2d):
        super(LocationAdaptiveLearner, self).__init__()
        self.nclass = nclass
        self.conv1 = self._pointwise_stage(in_channels, out_channels, norm_layer, activate=True)
        self.conv2 = self._pointwise_stage(out_channels, out_channels, norm_layer, activate=True)
        self.conv3 = self._pointwise_stage(out_channels, out_channels, norm_layer, activate=False)

    @staticmethod
    def _pointwise_stage(in_channels, out_channels, norm_layer, activate):
        """Build one ``conv1x1 -> norm [-> ReLU]`` stage as an ``nn.Sequential``."""
        layers = [
            nn.Conv2d(in_channels, out_channels, 1, bias=True),
            norm_layer(out_channels),
        ]
        if activate:
            layers.append(nn.ReLU(inplace=True))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` of shape (N, in_channels, H, W) to (N, nclass, -1, H, W)."""
        for stage in (self.conv1, self.conv2, self.conv3):
            x = stage(x)  # (N, out_channels, H, W)
        batch = x.size(0)
        height = x.size(2)
        width = x.size(3)
        # Split the channel axis into nclass groups; the group size is inferred.
        return x.view(batch, self.nclass, -1, height, width)
if __name__ == '__main__':
    # Smoke run: build a single-class DFF on a ResNet-50 backbone, push one
    # random image-sized tensor through it and print each output's size.
    net = DFF(1, backbone='resnet50')
    sample = torch.rand(1, 3, 524, 1000)
    for prediction in net(sample):
        print(prediction.size())