Commit
use convolution on query with padding to give the network absolute spatial awareness in addition to relative encoding from rotary embeddings

lucidrains committed Apr 14, 2021
1 parent 6289619 commit 53b3af0
Showing 1 changed file with 40 additions and 8 deletions.
48 changes: 40 additions & 8 deletions vit_pytorch/rvt.py
@@ -43,6 +43,16 @@ def forward(self, x):
        sin, cos = map(lambda t: repeat(t, 'n d -> () n (d j)', j = 2), (sin, cos))
        return sin, cos

+class DepthWiseConv2d(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias)
+        )
+    def forward(self, x):
+        return self.net(x)

# helper classes

class PreNorm(nn.Module):
@@ -53,6 +63,18 @@ def __init__(self, dim, fn):
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

+class SpatialConv(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel, bias = False):
+        super().__init__()
+        self.conv = DepthWiseConv2d(dim_in, dim_out, kernel, padding = kernel // 2, bias = False)

+    def forward(self, x, fmap_dims):
+        cls_token, x = x[:, :1], x[:, 1:]
+        x = rearrange(x, 'b (h w) d -> b d h w', **fmap_dims)
+        x = self.conv(x)
+        x = rearrange(x, 'b d h w -> b (h w) d')
+        return torch.cat((cls_token, x), dim = 1)

class GEGLU(nn.Module):
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
@@ -72,24 +94,30 @@ def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
-    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
+    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., conv_query_kernel = 9):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim = -1)
-        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

+        self.to_q = SpatialConv(dim, inner_dim, conv_query_kernel, bias = False)

+        self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )

-    def forward(self, x, pos_emb):
+    def forward(self, x, pos_emb, fmap_dims):
        b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x).chunk(3, dim = -1)

+        q = self.to_q(x, fmap_dims = fmap_dims)
+        qkv = (q, *self.to_kv(x).chunk(2, dim = -1))

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)

        # apply 2d rotary embeddings to queries and keys, excluding CLS tokens
@@ -121,11 +149,11 @@ def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
            ]))
-    def forward(self, x):
+    def forward(self, x, fmap_dims):
        pos_emb = self.pos_emb(x[:, 1:])

        for attn, ff in self.layers:
-            x = attn(x, pos_emb = pos_emb) + x
+            x = attn(x, pos_emb = pos_emb, fmap_dims = fmap_dims) + x
            x = ff(x) + x
        return x

@@ -138,6 +166,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

+        self.patch_size = patch_size
        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
            nn.Linear(patch_dim, dim),
@@ -152,12 +181,15 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
        )

    def forward(self, img):
+        b, _, h, w, p = *img.shape, self.patch_size

        x = self.to_patch_embedding(img)
-        b, n, _ = x.shape
+        n = x.shape[1]

        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

-        x = self.transformer(x)
+        fmap_dims = {'h': h // p, 'w': w // p}
+        x = self.transformer(x, fmap_dims = fmap_dims)

        return self.mlp_head(x)

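A note on the mechanism described in the commit message (this sketch is not part of the commit): with zero padding, a convolution computed near the border of the feature map sees a partially empty window, so even a spatially shared kernel produces position-dependent output there. That border asymmetry is the absolute spatial cue the query convolution contributes, complementing the relative positions supplied by the rotary embeddings in Attention.forward. A minimal illustration, assuming only torch is installed and mirroring the depthwise stage of the DepthWiseConv2d added above:

import torch
import torch.nn as nn

dim, kernel = 8, 9

# depthwise conv with 'same' padding, as in the first stage of the commit's DepthWiseConv2d
conv = nn.Conv2d(dim, dim, kernel_size = kernel, padding = kernel // 2, groups = dim, bias = False)
nn.init.ones_(conv.weight)  # constant weights, so any variation in the output comes from padding alone

fmap = torch.ones(1, dim, 16, 16)  # spatially constant input feature map
out = conv(fmap)

print(out[0, 0, 8, 8].item())  # interior position: the full 9x9 window sees real input -> 81.0
print(out[0, 0, 0, 0].item())  # corner position: only a 5x5 portion overlaps the input -> 25.0

The two outputs differ only because of where each position sits in the grid, which is the kind of absolute information the rotary embeddings alone do not provide.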