Commit

adding a better way to do attention
Abhay Gupta committed Nov 3, 2020
1 parent f734d41 commit 4e4e920
Showing 2 changed files with 61 additions and 23 deletions.
vit/models/Attention.py (47 changes: 30 additions & 17 deletions)
@@ -1,31 +1,44 @@
 import math
 import torch
 import torch.nn as nn
 from einops import rearrange
 import torch.nn.functional as F


 class SelfAttention(nn.Module):
-    def __init__(self, dim, heads=8):
+    def __init__(
+        self, dim, heads=8, qkv_bias=False, qk_scale=None, dropout_rate=0.0
+    ):
         super().__init__()
-        self.heads = heads
-        self.scale = dim ** -0.5
+        self.num_heads = heads
+        head_dim = dim // heads
+        self.scale = qk_scale or head_dim ** -0.5

-        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
-        self.to_out = nn.Linear(dim, dim)
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(dropout_rate)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(dropout_rate)

     def forward(self, x):
-        b, n, _, h = *x.shape, self.heads
-        qkv = self.to_qkv(x)
-        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv=3, h=h)
-
-        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
-        attn = dots.softmax(dim=-1)
-
-        out = torch.einsum('bhij,bhjd->bhid', attn, v)
-        out = rearrange(out, 'b h n d -> b n (h d)')
-        out = self.to_out(out)
-        return out
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        attn = (q @ k.transpose(-2, -1)) * self.scale
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x


 class AxialAttention(nn.Module):
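For context, here is a minimal sketch of how the reworked SelfAttention could be exercised. The shapes, the qkv_bias and dropout_rate values, and the import path are illustrative assumptions, not taken from the repository:

```python
import torch

from Attention import SelfAttention  # assumed import path, mirroring Transformer.py below

# Illustrative sizes: batch of 2, 197 tokens (196 patches + class token), embed dim 768, 12 heads.
attn = SelfAttention(dim=768, heads=12, qkv_bias=True, dropout_rate=0.1)

x = torch.randn(2, 197, 768)
attn.eval()  # disable dropout for a deterministic check
with torch.no_grad():
    out = attn(x)

print(out.shape)  # torch.Size([2, 197, 768]) -- the (B, N, C) shape is preserved
```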
vit/models/Transformer.py (37 changes: 31 additions & 6 deletions)
@@ -1,5 +1,5 @@
 from torch import nn
-from Attention import SelfAttention, AxialAttention
+from Attention import SelfAttention


 class Residual(nn.Module):
@@ -21,26 +21,51 @@ def forward(self, x):
         return self.fn(self.norm(x))


+class PreNormDrop(nn.Module):
+    def __init__(self, dim, dropout_rate, fn):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim)
+        self.dropout = nn.Dropout(p=dropout_rate)
+        self.fn = fn
+
+    def forward(self, x):
+        return self.dropout(self.fn(self.norm(x)))
+
+
 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim):
+    def __init__(self, dim, hidden_dim, dropout_rate):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, hidden_dim), nn.GELU(), nn.Linear(hidden_dim, dim)
+            nn.Linear(dim, hidden_dim),
+            nn.GELU(),
+            nn.Dropout(p=dropout_rate),
+            nn.Linear(hidden_dim, dim),
+            nn.Dropout(p=dropout_rate),
         )

     def forward(self, x):
         return self.net(x)


 class TransformerModel(nn.Module):
-    def __init__(self, dim, depth, heads, mlp_dim):
+    def __init__(self, dim, depth, heads, mlp_dim, dropout_rate=0.1):
         super().__init__()
         layers = []
         for _ in range(depth):
             layers.extend(
                 [
-                    Residual(PreNorm(dim, SelfAttention(dim, heads=heads))),
-                    Residual(PreNorm(dim, FeedForward(dim, mlp_dim))),
+                    Residual(
+                        PreNormDrop(
+                            dim,
+                            dropout_rate,
+                            SelfAttention(
+                                dim, heads=heads, dropout_rate=dropout_rate
+                            ),
+                        )
+                    ),
+                    Residual(
+                        PreNorm(dim, FeedForward(dim, mlp_dim, dropout_rate))
+                    ),
                 ]
             )
         self.net = nn.Sequential(*layers)
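Similarly, a small usage sketch for the updated TransformerModel with its new dropout_rate argument. The dimensions are illustrative assumptions, and it presumes the model's forward method (collapsed in this diff) simply applies self.net:

```python
import torch

from Transformer import TransformerModel  # assumed import path

# Illustrative config: 6 blocks, 8 heads, embed dim 256 (head_dim 32), MLP dim 512, 10% dropout.
model = TransformerModel(dim=256, depth=6, heads=8, mlp_dim=512, dropout_rate=0.1)

tokens = torch.randn(4, 50, 256)  # (batch, sequence length, embed dim)
out = model(tokens)  # assumes forward runs self.net, as the Sequential construction suggests

print(out.shape)  # torch.Size([4, 50, 256]) -- residual blocks keep the token shape
```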
