Skip to content

Commit

Permalink
Update cbam.py
Browse files Browse the repository at this point in the history
add cbam.py
  • Loading branch information
MenghaoGuo authored Nov 30, 2021
1 parent 12701bd commit c465452
Showing 1 changed file with 94 additions and 0 deletions.
94 changes: 94 additions & 0 deletions code/channel_spatial_attentions/cbam.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import jittor as jt
import jittor.nn as nn

class BasicConv(nn.Module):
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, bn=True, bias=False):
super(BasicConv, self).__init__()
self.out_channels = out_planes
self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.bn = nn.BatchNorm2d(out_planes,eps=1e-5, momentum=0.01, affine=True) if bn else None
self.relu = nn.ReLU() if relu else None

def execute(self, x):
x = self.conv(x)
if self.bn is not None:
x = self.bn(x)
if self.relu is not None:
x = self.relu(x)
return x


class ChannelGate(nn.Module):
def __init__(self, channel, reduction=16):
super(ChannelGate, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc_avg = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(),
nn.Linear(channel // reduction, channel, bias=False),
)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc_max = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=False),
nn.ReLU(),
nn.Linear(channel // reduction, channel, bias=False),
)

self.sigmoid = nn.Sigmoid()

def execute(self, x):
b, c, _, _ = x.size()
y_avg = self.avg_pool(x).view(b, c)
y_avg = self.fc_avg(y_avg).view(b, c, 1, 1)

y_max = self.max_pool(x).view(b, c)
y_max = self.fc_max(y_max).view(b, c, 1, 1)

y = self.sigmoid(y_avg+ y_avg)
return x * y.expand_as(x)


class ChannelPool(nn.Module):
def __init__(self):
super().__init__()

def execute(self, x):
x_max = jt.max(x,1).unsqueeze(1)
x_avg = jt.mean(x,1).unsqueeze(1)
x = jt.concat([x_max, x_avg], dim=1)
return x

class SpatialGate(nn.Module):
def __init__(self):
super(SpatialGate, self).__init__()
kernel_size = 7
self.compress = ChannelPool()
self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size-1) // 2, relu=False)
self.sigmoid = nn.Sigmoid()
def execute(self, x):
x_compress = self.compress(x)
x_out = self.spatial(x_compress)
scale = self.sigmoid(x_out) # broadcasting
return x * scale

class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16):
super(CBAM, self).__init__()
self.ChannelGate = ChannelGate(gate_channels, reduction_ratio)
self.SpatialGate = SpatialGate()
def execute(self, x):
x_out = self.ChannelGate(x)
x_out = self.SpatialGate(x_out)
return x_out


def main():
attention_blcok = CBAM(64)
input = jt.rand([4, 64, 32, 32])
output = attention_blcok(input)
print (input.size(), output.size())

if __name__ == '__main__':
main()


0 comments on commit c465452

Please sign in to comment.