Commit

Merge pull request MenghaoGuo#14 from uyzhang/main
update gc_module
uyzhang authored Dec 24, 2021
2 parents f3959f1 + b2815ad commit 563b90d
Showing 2 changed files with 66 additions and 1 deletion.
2 changes: 1 addition & 1 deletion code/channel_attentions/soca_module.py
@@ -1,4 +1,4 @@
-# Second-Order Attention Network for Single Image Super-Resolution
+# Second-Order Attention Network for Single Image Super-Resolution (CVPR 2019)
import jittor as jt
from jittor import nn, Function

65 changes: 65 additions & 0 deletions code/spatial_attentions/gc_module.py
@@ -0,0 +1,65 @@
import jittor as jt
from jittor import nn


class GlobalContextBlock(nn.Module):
    def __init__(self,
                 inplanes,
                 ratio):
        super(GlobalContextBlock, self).__init__()
        self.inplanes = inplanes
        self.ratio = ratio
        self.planes = int(inplanes * ratio)
        self.conv_mask = nn.Conv2d(inplanes, 1, kernel_size=1)
        self.softmax = nn.Softmax(dim=2)

        self.channel_add_conv = nn.Sequential(
            nn.Conv2d(self.inplanes, self.planes, kernel_size=1),
            nn.LayerNorm([self.planes, 1, 1]),
            nn.ReLU(),  # yapf: disable
            nn.Conv2d(self.planes, self.inplanes, kernel_size=1))

    def spatial_pool(self, x):
        batch, channel, height, width = x.size()

        input_x = x
        # [N, C, H * W]
        input_x = input_x.view(batch, channel, height * width)
        # [N, 1, C, H * W]
        input_x = input_x.unsqueeze(1)
        # [N, 1, H, W]
        context_mask = self.conv_mask(x)
        # [N, 1, H * W]
        context_mask = context_mask.view(batch, 1, height * width)
        # [N, 1, H * W]
        context_mask = self.softmax(context_mask)
        # [N, 1, H * W, 1]
        context_mask = context_mask.unsqueeze(-1)
        # [N, 1, C, 1]
        context = jt.matmul(input_x, context_mask)
        # [N, C, 1, 1]
        context = context.view(batch, channel, 1, 1)

        return context

    def execute(self, x):
        # [N, C, 1, 1]
        context = self.spatial_pool(x)

        out = x
        # [N, C, 1, 1]
        channel_add_term = self.channel_add_conv(context)
        out = out + channel_add_term

        return out


def main():
    attention_block = GlobalContextBlock(64, 1 / 4)
    input = jt.rand([4, 64, 32, 32])
    output = attention_block(input)
    print(input.size(), output.size())


if __name__ == '__main__':
    main()
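
For reference, this module follows the GCNet-style simplified global context block (Cao et al., "GCNet: Non-local Networks Meet Squeeze-Excitation Networks and Beyond"): a 1x1 convolution plus softmax pools the feature map into a single [N, C, 1, 1] context vector, which a bottleneck transform turns into a per-channel term added back to every spatial position. The sketch below shows one way the block might be dropped into a small residual stage; ConvGCBlock, its layer choices, and the residual wiring are illustrative assumptions and are not part of this commit — only GlobalContextBlock comes from the diff above.

import jittor as jt
from jittor import nn

# Assumes GlobalContextBlock is importable from the file added above,
# e.g. code/spatial_attentions/gc_module.py (hypothetical import path).


class ConvGCBlock(nn.Module):
    # Hypothetical wrapper: 3x3 conv + ReLU, then global context, with an identity shortcut.
    def __init__(self, channels, ratio=1 / 4):
        super(ConvGCBlock, self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU())
        self.gc = GlobalContextBlock(channels, ratio)

    def execute(self, x):
        # Convolve, add the pooled global context term, keep the residual connection.
        return x + self.gc(self.conv(x))


block = ConvGCBlock(64)
y = block(jt.rand([2, 64, 56, 56]))
print(y.shape)  # expected: [2, 64, 56, 56]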
