Skip to content

Commit

Permalink
Update se_module.py
Browse files Browse the repository at this point in the history
add se block
  • Loading branch information
MenghaoGuo authored Nov 30, 2021
1 parent 8447e62 commit 3686d06
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions code/channel_attentions/se_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import jittor as jt
import jittor.nn as nn

class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(),
nn.Linear(channel // reduction, channel, bias=False),
nn.Sigmoid()
)

def execute(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)


def main():
attention_blcok = SELayer(64)
input = jt.rand([4, 64, 32, 32])
output = attention_blcok(input)
print (input.size(), output.size())

if __name__ == '__main__':
main()


0 comments on commit 3686d06

Please sign in to comment.