add zeroing of weight parameters of batchnorm in levit just before residual connection, noticed by @EelcoHoogendoorn
lucidrains committed Apr 27, 2021
1 parent 3df6c31 commit 0df1505
Showing 2 changed files with 5 additions and 2 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.16.11',
+  version = '0.16.12',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
5 changes: 4 additions & 1 deletion vit_pytorch/levit.py
@@ -53,10 +53,13 @@ def __init__(self, dim, fmap_size, heads = 8, dim_key = 32, dim_value = 64, drop

         self.attend = nn.Softmax(dim = -1)

+        out_batch_norm = nn.BatchNorm2d(dim_out)
+        nn.init.zeros_(out_batch_norm.weight)
+
         self.to_out = nn.Sequential(
             nn.GELU(),
             nn.Conv2d(inner_dim_value, dim_out, 1),
-            nn.BatchNorm2d(dim_out),
+            out_batch_norm,
             nn.Dropout(dropout)
         )
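For context, this change mirrors the zero-γ trick used in residual networks: with the final BatchNorm's weight (γ) zeroed, the branch feeding the residual connection outputs zeros at initialization, so the block starts as the identity and early training is more stable. Below is a minimal sketch of the effect; the `ZeroInitResidual` class and its conv layout are hypothetical illustrations, not LeViT's actual attention module.

```python
import torch
from torch import nn

# Hypothetical residual block (not LeViT's attention module):
# zero-initializing the final BatchNorm2d weight (gamma) makes the
# branch output zeros at init, so x + branch(x) == x to start with.
class ZeroInitResidual(nn.Module):
    def __init__(self, dim):
        super().__init__()
        out_batch_norm = nn.BatchNorm2d(dim)
        nn.init.zeros_(out_batch_norm.weight)  # gamma = 0; bias (beta) defaults to 0

        self.branch = nn.Sequential(
            nn.Conv2d(dim, dim, 3, padding = 1),
            nn.GELU(),
            nn.Conv2d(dim, dim, 1),
            out_batch_norm,
        )

    def forward(self, x):
        return x + self.branch(x)  # residual connection

block = ZeroInitResidual(8)
x = torch.randn(2, 8, 16, 16)
assert torch.allclose(block(x), x)  # block is the identity at initialization
```

Note that γ itself still receives a nonzero gradient, so the branch "switches on" after the first optimizer step; only the initial forward pass is an exact identity.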
