
Commit

Add use_glu flag to RvT
jon-tow committed Apr 30, 2021
1 parent 7807f24 commit 6f3a5fc
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions vit_pytorch/rvt.py
@@ -83,11 +83,11 @@ def forward(self, x):
         return F.gelu(gates) * x
 
 class FeedForward(nn.Module):
-    def __init__(self, dim, hidden_dim, dropout = 0.):
+    def __init__(self, dim, hidden_dim, dropout = 0., use_glu = True):
         super().__init__()
         self.net = nn.Sequential(
-            nn.Linear(dim, hidden_dim * 2),
-            GEGLU(),
+            nn.Linear(dim, hidden_dim * 2 if use_glu else hidden_dim),
+            GEGLU() if use_glu else nn.GELU(),
             nn.Dropout(dropout),
             nn.Linear(hidden_dim, dim),
             nn.Dropout(dropout)
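The hunk above only touches the feed-forward block: with use_glu = True (the default) the first projection widens to hidden_dim * 2 so the GEGLU activation can split it into a value half and a gate half, and with use_glu = False it falls back to a plain Linear -> GELU MLP. A minimal sketch of that gating, assuming GEGLU matches the F.gelu(gates) * x forward shown in the context line above:

import torch
import torch.nn.functional as F
from torch import nn

class GEGLU(nn.Module):
    # Split the doubled projection in half: one half is the value,
    # the other half (after a GELU) gates it elementwise.
    def forward(self, x):
        x, gates = x.chunk(2, dim = -1)
        return F.gelu(gates) * x

dim, hidden_dim = 512, 1024
ff = nn.Sequential(
    nn.Linear(dim, hidden_dim * 2),  # doubled width, consumed by GEGLU
    GEGLU(),                         # back down to hidden_dim
    nn.Linear(hidden_dim, dim),
)
print(ff(torch.randn(1, 16, dim)).shape)  # torch.Size([1, 16, 512])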
@@ -154,14 +154,14 @@ def forward(self, x, pos_emb, fmap_dims):
         return self.to_out(out)
 
 class Transformer(nn.Module):
-    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True):
+    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
         super().__init__()
         self.layers = nn.ModuleList([])
         self.pos_emb = AxialRotaryEmbedding(dim_head)
         for _ in range(depth):
             self.layers.append(nn.ModuleList([
                 PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout, use_rotary = use_rotary, use_ds_conv = use_ds_conv)),
-                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout))
+                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout, use_glu = use_glu))
             ]))
     def forward(self, x, fmap_dims):
         pos_emb = self.pos_emb(x[:, 1:])
@@ -174,7 +174,7 @@ def forward(self, x, fmap_dims):
 # Rotary Vision Transformer
 
 class RvT(nn.Module):
-    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True):
+    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0., use_rotary = True, use_ds_conv = True, use_glu = True):
         super().__init__()
         assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
         num_patches = (image_size // patch_size) ** 2
@@ -187,7 +187,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
         )
 
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
-        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv)
+        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout, use_rotary, use_ds_conv, use_glu)
 
         self.mlp_head = nn.Sequential(
             nn.LayerNorm(dim),
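End to end, the flag is simply threaded RvT -> Transformer -> FeedForward, so callers can opt out of the gated feed-forward without touching anything else. A usage sketch, assuming the rest of the RvT constructor matches the signature shown above (hyperparameter values here are illustrative, not taken from the repository):

import torch
from vit_pytorch.rvt import RvT

model = RvT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    use_glu = False,   # new flag: plain GELU MLPs instead of GEGLU
)

img = torch.randn(1, 3, 256, 256)
preds = model(img)   # logits of shape (1, 1000)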
