Merge pull request MenghaoGuo#18 from uyzhang/main
update triplet_attention
uyzhang authored Dec 27, 2021
2 parents 5ca4d95 + 8444eb1 commit 42e0bc4
Showing 1 changed file with 103 additions and 0 deletions.
code/channel_spatial_attentions/triplet_attention.py
@@ -0,0 +1,103 @@
# Rotate to attend: Convolutional triplet attention module (WACV 2021)
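# Triplet attention computes cross-dimension attention with three parallel
# branches: two rotate the input tensor so that a Z-pool + 7x7 conv gate can
# model (C, W) and (C, H) interactions, and an optional third branch applies
# the usual (H, W) spatial attention. The branch outputs are averaged.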
import jittor as jt
from jittor import nn


class BasicConv(nn.Module):
    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias=False,
    ):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.bn = (
            nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
            if bn
            else None
        )
        self.relu = nn.ReLU() if relu else None

    def execute(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class ZPool(nn.Module):
    # Z-pool: concatenate max-pooling and average-pooling along dim 1,
    # compressing that dimension to 2 channels.
    def execute(self, x):
        return jt.concat(
            (x.max(1).unsqueeze(1), x.mean(1).unsqueeze(1)), dim=1
        )


class AttentionGate(nn.Module):
    # Z-pool -> 7x7 conv -> sigmoid, producing a single-channel attention map
    # that rescales the input.
    def __init__(self):
        super(AttentionGate, self).__init__()
        kernel_size = 7
        self.compress = ZPool()
        self.conv = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def execute(self, x):
        x_compress = self.compress(x)
        x_out = self.conv(x_compress)
        scale = x_out.sigmoid()
        return x * scale


class TripletAttention(nn.Module):
    def __init__(self, no_spatial=False):
        super(TripletAttention, self).__init__()
        self.cw = AttentionGate()
        self.hc = AttentionGate()
        self.no_spatial = no_spatial
        if not no_spatial:
            self.hw = AttentionGate()

    def execute(self, x):
        # Branch 1: rotate (N, C, H, W) -> (N, H, C, W) to gate the (C, W) plane,
        # then rotate back.
        x_perm1 = x.permute(0, 2, 1, 3)
        x_out1 = self.cw(x_perm1)
        x_out11 = x_out1.permute(0, 2, 1, 3)
        # Branch 2: rotate (N, C, H, W) -> (N, W, H, C) to gate the (H, C) plane,
        # then rotate back.
        x_perm2 = x.permute(0, 3, 2, 1)
        x_out2 = self.hc(x_perm2)
        x_out21 = x_out2.permute(0, 3, 2, 1)
        if not self.no_spatial:
            # Branch 3: standard spatial (H, W) attention; average the three branches.
            x_out = self.hw(x)
            x_out = 1 / 3 * (x_out + x_out11 + x_out21)
        else:
            x_out = 1 / 2 * (x_out11 + x_out21)
        return x_out


def main():
    attention_block = TripletAttention()
    input = jt.ones([4, 64, 32, 32])
    output = attention_block(input)
    print(input.size(), output.size())


if __name__ == '__main__':
    main()
