From 3cffecf74407957942387c10e39b2990e8dd3cbd Mon Sep 17 00:00:00 2001
From: uyzhang
Date: Wed, 5 Jan 2022 22:46:47 +0800
Subject: [PATCH] update offset_attention

---
 code/spatial_attentions/offset_module.py | 40 ++++++++++++++++++++++++
 1 file changed, 40 insertions(+)
 create mode 100644 code/spatial_attentions/offset_module.py

diff --git a/code/spatial_attentions/offset_module.py b/code/spatial_attentions/offset_module.py
new file mode 100644
index 0000000..9a59ded
--- /dev/null
+++ b/code/spatial_attentions/offset_module.py
@@ -0,0 +1,40 @@
+# PCT: Point Cloud Transformer (CVMJ 2021)
+
+import jittor as jt
+from jittor import nn
+
+
+class SA_Layer(nn.Module):
+    def __init__(self, channels):
+        super(SA_Layer, self).__init__()
+        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
+        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
+        self.q_conv.weight = self.k_conv.weight  # query and key projections share weights
+        self.v_conv = nn.Conv1d(channels, channels, 1)
+        self.trans_conv = nn.Conv1d(channels, channels, 1)
+        self.after_norm = nn.BatchNorm1d(channels)
+        self.act = nn.ReLU()
+        self.softmax = nn.Softmax(dim=-1)
+
+    def execute(self, x):
+        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c
+        x_k = self.k_conv(x)  # b, c, n
+        x_v = self.v_conv(x)  # b, c, n
+        energy = nn.bmm(x_q, x_k)  # b, n, n
+        attention = self.softmax(energy)
+        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))  # softmax followed by L1 normalization
+        x_r = nn.bmm(x_v, attention)  # b, c, n
+        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))  # LBR (conv + BN + ReLU) on the offset x - x_r
+        x = x + x_r  # residual connection
+        return x
+
+
+def main():
+    attention_block = SA_Layer(64)
+    input = jt.rand([4, 64, 32])
+    output = attention_block(input)
+    print(input.size(), output.size())
+
+
+if __name__ == '__main__':
+    main()
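
For context (not part of the patch): SA_Layer implements PCT's offset attention, where the attention map is softmax-normalized and then L1-normalized, and the network transforms the offset x - x_r through a conv-BN-ReLU block before adding it back to the input. The sketch below shows one way several such layers could be stacked and their outputs concatenated, roughly in the spirit of the PCT encoder; the class name OffsetAttentionEncoder, the import path, the channel width, and the depth of four layers are illustrative assumptions, not code from this patch.

# Illustrative sketch only (assumed wrapper, not part of this patch): stack four
# SA_Layer blocks and concatenate their refined features along the channel axis.
import jittor as jt
from jittor import nn

from offset_module import SA_Layer  # assumes running next to the module added by this patch


class OffsetAttentionEncoder(nn.Module):
    def __init__(self, channels=64):
        super(OffsetAttentionEncoder, self).__init__()
        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

    def execute(self, x):
        # x: b, c, n per-point features; each offset-attention layer refines them in turn
        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        # concatenate the four refined feature maps: b, 4 * c, n
        return jt.concat([x1, x2, x3, x4], dim=1)

For example, OffsetAttentionEncoder(64)(jt.rand([4, 64, 32])) would produce features of shape [4, 256, 32].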