
Merge pull request MenghaoGuo#23 from uyzhang/main
update
uyzhang authored Jan 5, 2022
2 parents c096b1b + 3cffecf commit 73918ba
Showing 2 changed files with 126 additions and 0 deletions.
40 changes: 40 additions & 0 deletions code/spatial_attentions/offset_module.py
@@ -0,0 +1,40 @@
# PCT: Point Cloud Transformer (CVMJ 2021)

import jittor as jt
from jittor import nn


class SA_Layer(nn.Module):
def __init__(self, channels):
super(SA_Layer, self).__init__()
self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.q_conv.weight = self.k_conv.weight  # share weights between the query and key projections
self.v_conv = nn.Conv1d(channels, channels, 1)
self.trans_conv = nn.Conv1d(channels, channels, 1)
self.after_norm = nn.BatchNorm1d(channels)
self.act = nn.ReLU()
self.softmax = nn.Softmax(dim=-1)

def execute(self, x):
x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c
x_k = self.k_conv(x) # b, c, n
x_v = self.v_conv(x)
energy = nn.bmm(x_q, x_k) # b, n, n
attention = self.softmax(energy)
        attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True))  # extra L1 normalization over dim=1 (offset-attention)
x_r = nn.bmm(x_v, attention) # b, c, n
        x_r = self.act(self.after_norm(self.trans_conv(x - x_r)))  # transform the offset (x - x_r) with conv + BN + ReLU
x = x + x_r
return x


def main():
attention_block = SA_Layer(64)
input = jt.rand([4, 64, 32])
output = attention_block(input)
print(input.size(), output.size())


if __name__ == '__main__':
main()
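
A usage sketch for the layer above (not part of the committed file; it assumes offset_module.py is importable from the working directory). The offset attention preserves the input shape: the softmaxed, re-normalized attention map mixes the value features, and the offset x - x_r is transformed and added back as a residual.

import jittor as jt
from offset_module import SA_Layer  # assumed import path

layer = SA_Layer(64)
x = jt.rand([2, 64, 128])   # (batch, channels, points)
y = layer(x)
print(x.size(), y.size())   # both (2, 64, 128): the block preserves (b, c, n)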
86 changes: 86 additions & 0 deletions code/spatial_temporal_attentions/dstt_module.py
@@ -0,0 +1,86 @@
# Decoupled spatial-temporal transformer for video inpainting (arXiv 2021)
import math
import jittor as jt
from jittor import nn


class Attention(nn.Module):
"""
    Compute 'Scaled Dot Product Attention'.
"""

def __init__(self, p=0.1):
super(Attention, self).__init__()
self.dropout = nn.Dropout(p=p)

def execute(self, query, key, value):
scores = jt.matmul(query, key.transpose(-2, -1)
) / math.sqrt(query.size(-1))
p_attn = nn.softmax(scores, dim=-1)
p_attn = self.dropout(p_attn)
p_val = jt.matmul(p_attn, value)
return p_val, p_attn


class MultiHeadedAttention(nn.Module):
"""
Take in model size and number of heads.
"""

def __init__(self, tokensize, d_model, head, mode, p=0.1):
super().__init__()
self.mode = mode
self.query_embedding = nn.Linear(d_model, d_model)
self.value_embedding = nn.Linear(d_model, d_model)
self.key_embedding = nn.Linear(d_model, d_model)
self.output_linear = nn.Linear(d_model, d_model)
self.attention = Attention(p=p)
self.head = head
self.h, self.w = tokensize

def execute(self, x, t):
bt, n, c = x.size()
b = bt // t
c_h = c // self.head
key = self.key_embedding(x)
query = self.query_embedding(x)
value = self.value_embedding(x)
        if self.mode == 's':  # spatial mode: per-frame multi-head attention over all n tokens
key = key.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4)
query = query.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4)
value = value.view(b, t, n, self.head, c_h).permute(0, 1, 3, 2, 4)
att, _ = self.attention(query, key, value)
att = att.permute(0, 1, 3, 2, 4).view(bt, n, c)
        elif self.mode == 't':  # temporal mode: 2x2 windows of size (h/2, w/2), attention across all frames in each window
key = key.view(b, t, 2, self.h//2, 2, self.w//2, self.head, c_h)
key = key.permute(0, 2, 4, 6, 1, 3, 5, 7).view(
b, 4, self.head, -1, c_h)
query = query.view(b, t, 2, self.h//2, 2,
self.w//2, self.head, c_h)
query = query.permute(0, 2, 4, 6, 1, 3, 5, 7).view(
b, 4, self.head, -1, c_h)
value = value.view(b, t, 2, self.h//2, 2,
self.w//2, self.head, c_h)
value = value.permute(0, 2, 4, 6, 1, 3, 5, 7).view(
b, 4, self.head, -1, c_h)
att, _ = self.attention(query, key, value)
att = att.view(b, 2, 2, self.head, t, self.h//2, self.w//2, c_h)
att = att.permute(0, 4, 1, 5, 2, 6, 3,
7).view(bt, n, c)
output = self.output_linear(att)
return output


def main():
attention_block_s = MultiHeadedAttention(
tokensize=[4, 8], d_model=64, head=4, mode='s')
attention_block_t = MultiHeadedAttention(
tokensize=[4, 8], d_model=64, head=4, mode='t')
input = jt.rand([8, 32, 64])
output = attention_block_s(input, 2)
output = attention_block_t(output, 2)
print(input.size(), output.size())


if __name__ == '__main__':
main()
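
A usage sketch for the decoupled attention above (not part of the committed file; the import path is assumed). The input packs batch and time into the first dimension, so x has shape (b*t, n, c) with n = h*w matching tokensize; mode 's' attends within each frame, mode 't' attends across frames inside each 2x2 window, and both preserve the input shape.

import jittor as jt
from dstt_module import MultiHeadedAttention  # assumed import path

b, t, h, w, c = 2, 4, 4, 8, 64        # n = h * w = 32 tokens per frame
x = jt.rand([b * t, h * w, c])        # frames are stacked along the batch axis
spatial = MultiHeadedAttention(tokensize=[h, w], d_model=c, head=4, mode='s')
temporal = MultiHeadedAttention(tokensize=[h, w], d_model=c, head=4, mode='t')
y = temporal(spatial(x, t), t)        # per-frame attention, then windowed cross-frame attention
print(x.size(), y.size())             # both (8, 32, 64)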
