Commit

update external_attention
uyzhang committed Jan 2, 2022
1 parent ff993b1 commit 349c555
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions code/spatial_attentions/external_attention.py
@@ -0,0 +1,57 @@
import jittor as jt
from jittor import nn


class External_attention(nn.Module):
    '''
    External attention: attends over a small, learnable external memory of
    k units instead of the input itself, so the cost is linear in the
    number of pixels.

    Arguments:
        c (int): The input and output channel number.
    '''

    def __init__(self, c):
        super(External_attention, self).__init__()

        self.conv1 = nn.Conv2d(c, c, 1)

        # k is the number of memory units in the external memory.
        self.k = 64
        self.linear_0 = nn.Conv1d(c, self.k, 1, bias=False)

        # linear_1 is initialized with the transpose of linear_0's weight,
        # tying the two memories at initialization.
        self.linear_1 = nn.Conv1d(self.k, c, 1, bias=False)
        self.linear_1.weight = self.linear_0.weight.permute(1, 0, 2)

        self.conv2 = nn.Sequential(
            nn.Conv2d(c, c, 1, bias=False),
            nn.BatchNorm(c))

        self.relu = nn.ReLU()

    def execute(self, x):
        idn = x  # keep the input for the residual connection
        x = self.conv1(x)

        b, c, h, w = x.size()
        n = h * w
        x = x.view(b, c, n)  # b * c * n

        attn = self.linear_0(x)  # b * k * n
        attn = nn.softmax(attn, dim=-1)  # softmax over the n pixels

        # Double normalization: l1-normalize over the k memory units.
        attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True))  # b * k * n
        x = self.linear_1(attn)  # b * c * n

        x = x.view(b, c, h, w)
        x = self.conv2(x)
        x = x + idn  # residual connection
        x = self.relu(x)
        return x


def main():
    attention_block = External_attention(64)
    x = jt.rand([4, 64, 32, 32])
    output = attention_block(x)
    print(x.size(), output.size())


if __name__ == '__main__':
    main()
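
For reference, the two 1x1 Conv1d layers above are just matrix multiplications with a learnable external memory of k units. Below is a minimal sketch of the same computation with the memories written out explicitly; the names Mk/Mv and the reliance on broadcasting matmul are illustrative assumptions, not part of the commit:

import jittor as jt

b, c, n, k = 2, 8, 16, 4
x = jt.rand(b, c, n)                # flattened feature map, b * c * n
Mk = jt.rand(k, c)                  # stands in for linear_0's 1x1-conv weight
Mv = Mk.permute(1, 0)               # linear_1 starts as Mk's transpose

attn = jt.matmul(Mk, x)             # b * k * n, similarity to each memory unit
attn = jt.nn.softmax(attn, dim=-1)  # softmax over the n pixels
attn = attn / (1e-9 + attn.sum(dim=1, keepdims=True))  # l1-norm over k
out = jt.matmul(Mv, attn)           # b * c * n, read back from the memory
print(out.shape)                    # [2, 8, 16]

Because k is a fixed constant (64 here), both matmuls are O(n) in the number of pixels, unlike the O(n^2) attention map of standard self-attention.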
