diff --git a/setup.py b/setup.py
index 7a42e3b..14c999e 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.33.0',
+  version = '0.33.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/max_vit.py b/vit_pytorch/max_vit.py
index cf9ac45..3db5ce4 100644
--- a/vit_pytorch/max_vit.py
+++ b/vit_pytorch/max_vit.py
@@ -28,6 +28,20 @@ def __init__(self, dim, fn):
     def forward(self, x):
         return self.fn(self.norm(x)) + x

+class FeedForward(nn.Module):
+    def __init__(self, dim, mult = 4, dropout = 0.):
+        super().__init__()
+        inner_dim = int(dim * mult)
+        self.net = nn.Sequential(
+            nn.Linear(dim, inner_dim),
+            nn.GELU(),
+            nn.Dropout(dropout),
+            nn.Linear(inner_dim, dim),
+            nn.Dropout(dropout)
+        )
+    def forward(self, x):
+        return self.net(x)
+
 # MBConv

 class SqueezeExcitation(nn.Module):
@@ -244,10 +258,12 @@ def __init__(
                     ),
                     Rearrange('b d (x w1) (y w2) -> b x y w1 w2 d', w1 = w, w2 = w),  # block-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (x w1) (y w2)'),

                     Rearrange('b d (w1 x) (w2 y) -> b x y w1 w2 d', w1 = w, w2 = w),  # grid-like attention
                     PreNormResidual(layer_dim, Attention(dim = layer_dim, dim_head = dim_head, dropout = dropout, window_size = w)),
+                    PreNormResidual(layer_dim, FeedForward(dim = layer_dim, dropout = dropout)),
                     Rearrange('b x y w1 w2 d -> b d (w1 x) (w2 y)'),
                 )
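The diff interleaves a pre-norm residual FeedForward after both the block-like and the grid-like attention in each MaxViT stage, so every block now runs MBConv → block attention → FF → grid attention → FF. Below is a minimal usage sketch, not part of the diff; the constructor arguments follow the MaxViT example in the vit-pytorch README and may differ between versions.

```python
# Hypothetical usage sketch: argument names assume the README's MaxViT example.
import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,            # dimension of the convolutional stem
    dim = 96,                      # dimension of the first stage
    dim_head = 32,                 # dimension per attention head
    depth = (2, 2, 5, 2),          # number of MaxViT blocks per stage
    window_size = 7,               # window size for block and grid attention
    mbconv_expansion_rate = 4,     # expansion rate of MBConv
    mbconv_shrinkage_rate = 0.25,  # shrinkage rate of squeeze-excitation in MBConv
    dropout = 0.1                  # dropout, also applied inside the new FeedForward layers
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)  # (1, 1000)
```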