Skip to content

Commit

Permalink
coordatt_module
Browse files Browse the repository at this point in the history
  • Loading branch information
uyzhang committed Dec 26, 2021
1 parent 76ecd0e commit d919367
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions code/channel_spatial_attentions/coordatt_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import jittor as jt
from jittor import nn


class h_sigmoid(nn.Module):
def __init__(self):
super(h_sigmoid, self).__init__()
self.relu = nn.ReLU6()

def execute(self, x):
return self.relu(x + 3) / 6


class h_swish(nn.Module):
def __init__(self):
super(h_swish, self).__init__()
self.sigmoid = h_sigmoid()

def execute(self, x):
return x * self.sigmoid(x)


class CoordAtt(nn.Module):
def __init__(self, inp, oup, reduction=32):
super(CoordAtt, self).__init__()
self.pool_h = nn.AdaptiveAvgPool2d((None, 1))
self.pool_w = nn.AdaptiveAvgPool2d((1, None))

mip = max(8, inp // reduction)

self.conv1 = nn.Conv2d(inp, mip, kernel_size=1, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(mip)
self.act = h_swish()

self.conv_h = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)
self.conv_w = nn.Conv2d(mip, oup, kernel_size=1, stride=1, padding=0)

def execute(self, x):
identity = x

n, c, h, w = x.size()
x_h = self.pool_h(x)
x_w = self.pool_w(x).permute(0, 1, 3, 2)

y = jt.concat([x_h, x_w], dim=2)
y = self.conv1(y)
y = self.bn1(y)
y = self.act(y)

x_h, x_w = jt.split(y, [h, w], dim=2)
x_w = x_w.permute(0, 1, 3, 2)

a_h = self.conv_h(x_h).sigmoid()
a_w = self.conv_w(x_w).sigmoid()

out = identity * a_w * a_h

return out


def main():
attention_block = CoordAtt(64, 64)
input = jt.rand([4, 64, 32, 32])
output = attention_block(input)
print(input.size(), output.size())


if __name__ == '__main__':
main()

0 comments on commit d919367

Please sign in to comment.