fix maxvit - need feedforwards after attention
lucidrains committed Apr 6, 2022
1 parent 2d4089c commit 13f8e12
Showing 2 changed files with 17 additions and 1 deletion.
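
For context before the diffs: a transformer block pairs every attention layer with a feedforward (MLP) sub-block, but prior to this commit the MaxViT stages here ran block-like and grid-like attention with no feedforwards at all. Below is a minimal sketch of the intended pre-norm residual ordering; `dim`, `attn`, and `ff` are illustrative stand-ins, not names from the repository:

```python
import torch
import torch.nn as nn

dim = 96                          # illustrative channel width, not from the commit
norm1, norm2 = nn.LayerNorm(dim), nn.LayerNorm(dim)
attn = nn.Identity()              # stand-in for the windowed attention in max_vit.py
ff = nn.Sequential(               # same shape as the FeedForward added below (mult = 4)
    nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim)
)

x = torch.randn(2, 49, dim)       # (batch, tokens in a 7x7 window, dim)
x = x + attn(norm1(x))            # attention sub-block: pre-norm + residual
x = x + ff(norm2(x))              # feedforward sub-block: the half this commit adds
```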
setup.py (2 changes: 1 addition & 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.0',
+  version = '0.33.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/max_vit.py (16 changes: 16 additions & 0 deletions)

@@ -28,6 +28,20 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x
 
+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # MBConv
 
 class SqueezeExcitation(nn.Module):
@@ -244,10 +258,12 @@ def __init__(
             ),
             Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
             PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+            PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
             Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),
 
             Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
             PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+            PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
             Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
         )
 

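Putting the second hunk together: each stage now alternates window (block-like) attention and grid-like attention, with a feedforward after each, all wrapped in PreNormResidual. Below is a self-contained sketch of that ordering, with a plain nn.Linear standing in for the windowed, relative-position-bias Attention so the snippet runs on its own; PreNormResidual and FeedForward match the diff, while w, layer_dim, and the input size are illustrative:

```python
import torch
import torch.nn as nn
from einops.layers.torch import Rearrange

class PreNormResidual(nn.Module):
    # pre-LayerNorm + residual, as in max_vit.py
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x):
        return self.fn(self.norm(x)) + x

class FeedForward(nn.Module):
    # the module added by this commit
    def __init__(self, dim, mult = 4, dropout = 0.):
        super().__init__()
        inner_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        )
    def forward(self, x):
        return self.net(x)

w, layer_dim = 7, 96   # window size and channel width (illustrative values)

stage = nn.Sequential(
    Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like windows
    PreNormResidual(layer_dim, nn.Linear(layer_dim, layer_dim)),      # stand-in for Attention
    PreNormResidual(layer_dim, FeedForward(dim = layer_dim)),         # feedforward added by this commit
    Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

    Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like windows
    PreNormResidual(layer_dim, nn.Linear(layer_dim, layer_dim)),      # stand-in for Attention
    PreNormResidual(layer_dim, FeedForward(dim = layer_dim)),         # feedforward added by this commit
    Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
)

x = torch.randn(1, layer_dim, 56, 56)   # spatial dims must be divisible by w
assert stage(x).shape == x.shape
```

The MBConv block that precedes these sub-blocks in the real MaxViT stage is omitted from the sketch.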
