Skip to content

Commit

Permalink
fix hidden dimension in MaxViT thanks to @arquolo
Browse files (browse the repository at this point in the history)
  • Loading branch information
lucidrains committed Jun 25, 2022
1 parent 6460119 commit 2c6dd70
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
setup(
name = 'vit-pytorch',
packages = find_packages(exclude=['examples']),
version = '0.35.4',
version = '0.35.5',
license='MIT',
description = 'Vision Transformer (ViT) - Pytorch',
long_description_content_type = 'text/markdown',
Expand Down
12 changes: 6 additions & 6 deletions vit_pytorch/max_vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,12 +100,12 @@ def MBConv(
stride = 2 if downsample else 1

net = nn.Sequential(
nn.Conv2d(dim_in, dim_out, 1),
nn.BatchNorm2d(dim_out),
nn.SiLU(),
nn.Conv2d(dim_out, dim_out, 3, stride = stride, padding = 1, groups = dim_out),
SqueezeExcitation(dim_out, shrinkage_rate = shrinkage_rate),
nn.Conv2d(dim_out, dim_out, 1),
nn.Conv2d(dim_in, hidden_dim, 1),
nn.BatchNorm2d(hidden_dim),
nn.GELU(),
nn.Conv2d(hidden_dim, hidden_dim, 3, stride = stride, padding = 1, groups = dim_out),
SqueezeExcitation(hidden_dim, shrinkage_rate = shrinkage_rate),
nn.Conv2d(hidden_dim, dim_out, 1),
nn.BatchNorm2d(dim_out)
)

Expand Down

0 comments on commit 2c6dd70

Please sign in to comment.