update segformer_module
uyzhang committed Jan 2, 2022
1 parent 8cdd5d1 commit 7786452
Showing 1 changed file with 63 additions and 0 deletions.
code/spatial_attentions/segformer_module.py (63 additions, 0 deletions)
@@ -0,0 +1,63 @@
# SegFormer: Simple and Efficient Design for Semantic Segmentation with Transformers
import jittor as jt
from jittor import nn


class EfficientAttention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Queries are projected from the full token sequence; keys and values
        # share one projection applied after the optional spatial reduction.
        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        # Spatial reduction: for sr_ratio > 1, keys/values are computed from a
        # feature map downsampled by sr_ratio in each spatial dimension, which
        # shrinks the attention matrix from N x N to N x (N / sr_ratio**2).
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(
                dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def execute(self, x, H, W):
        # x: (B, N, C) token sequence with N = H * W.
        B, N, C = x.shape
        # Queries: (B, num_heads, N, head_dim).
        q = self.q(x).reshape(B, N, self.num_heads, C //
                              self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            # Fold tokens back into a (B, C, H, W) map, downsample it with the
            # strided convolution, then flatten to a shorter token sequence.
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads,
                                     C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C //
                                    self.num_heads).permute(2, 0, 3, 1, 4)
        # k, v: (B, num_heads, N', head_dim) with N' = N / sr_ratio**2.
        k, v = kv[0], kv[1]

        # Scaled dot-product attention over the (possibly reduced) key/value set.
        attn = (q @ k.transpose(0, 1, 3, 2)) * self.scale
        attn = nn.softmax(attn, dim=-1)
        attn = self.attn_drop(attn)

        # Merge heads back into the channel dimension and apply the output projection.
        x = (attn @ v).transpose(0, 2, 1, 3).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


def main():
    # Quick shape check: the token count N must equal H * W (here 8 * 8 = 64).
    attention_block = EfficientAttention(64)
    input = jt.rand([4, 64, 64])
    output = attention_block(input, 8, 8)
    print(input.size(), output.size())


if __name__ == '__main__':
    main()
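
A minimal usage sketch (not part of the committed file, and assuming the EfficientAttention class above is in scope) for the spatial-reduction path: with sr_ratio=2, keys and values are computed on a 4x-downsampled token map, so the attention matrix shrinks from N x N to N x N/4 while the output keeps the input shape.

# Usage sketch: spatial-reduction path (sr_ratio > 1).
import jittor as jt

block = EfficientAttention(dim=64, num_heads=8, sr_ratio=2)
x = jt.rand([4, 64, 64])   # B=4, N=H*W=64, C=64
y = block(x, 8, 8)         # keys/values reduced to (8/2) * (8/2) = 16 tokens
print(y.shape)             # [4, 64, 64], same shape as the input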
