update offset_attention
uyzhang committed Jan 5, 2022
1 parent 211aece commit 3cffecf
Showing 1 changed file with 40 additions and 0 deletions.
code/spatial_attentions/offset_module.py (+40 −0)
@@ -0,0 +1,40 @@
# PCT: Point Cloud Transformer (CVMJ 2021)

import jittor as jt
from jittor import nn


class SA_Layer(nn.Module):
    """Offset-attention layer from PCT: x + LBR(x - SA(x))."""

    def __init__(self, channels):
        super(SA_Layer, self).__init__()
        # Query and key projections reduce the channels to c/4 and share weights.
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.q_conv.weight = self.k_conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        # LBR block applied to the offset: 1x1 conv + BatchNorm + ReLU.
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def execute(self, x):
        # x: (b, c, n) point-wise features.
        x_q = self.q_conv(x).permute(0, 2, 1)  # b, n, c/4
        x_k = self.k_conv(x)  # b, c/4, n
        x_v = self.v_conv(x)  # b, c, n
        energy = nn.bmm(x_q, x_k)  # b, n, n
        # Softmax over the last axis, then L1-normalize over the point axis
        # (PCT's replacement for scaling by sqrt(d)).
        attention = self.softmax(energy)
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))
        x_r = nn.bmm(x_v, attention)  # b, c, n
        # Offset attention: pass the offset (x - x_r) through the LBR block
        # and add it back to the input as a residual.
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))
        x = x + x_r
        return x


def main():
    attention_block = SA_Layer(64)
    input = jt.rand([4, 64, 32])  # b=4, c=64, n=32
    output = attention_block(input)
    print(input.size(), output.size())


if __name__ == '__main__':
    main()
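
For context, the PCT encoder chains several of these offset-attention layers and concatenates their outputs along the channel axis before the global pooling stage. Below is a minimal sketch of that usage pattern, reusing the SA_Layer and imports above; the StackedSA name and the jt.concat call are illustrative assumptions, not part of this commit.

# Sketch (assumption): stacking four offset-attention layers, PCT-style.
class StackedSA(nn.Module):
    def __init__(self, channels):
        super(StackedSA, self).__init__()
        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

    def execute(self, x):
        # Each layer refines the previous output; the four outputs are
        # concatenated along the channel axis, giving (b, 4*c, n).
        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        return jt.concat([x1, x2, x3, x4], dim=1)

Since SA_Layer preserves the input shape, StackedSA(64) applied to the (4, 64, 32) input from main() would return a (4, 256, 32) feature map.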
