fix small bug
lucidrains committed Apr 29, 2021
1 parent a612327 commit 7807f24
Showing 3 changed files with 5 additions and 5 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -338,7 +338,7 @@ pred = v(img) # (1, 1000)

<img src="./images/twins_svt.png" width="400px"></img>

- This <a href="https://arxiv.org/abs/2104.13840">paper</a> mixes local and global attention, along with position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, etc.
+ This <a href="https://arxiv.org/abs/2104.13840">paper</a> proposes mixing local and global attention, along with position encoding generator (proposed in <a href="https://arxiv.org/abs/2102.10882">CPVT</a>) and global average pooling, to achieve the same results as <a href="https://arxiv.org/abs/2103.14030">Swin</a>, without the extra complexity of shifted windows, CLS tokens, nor positional embeddings.

```python
import torch
```
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
    name = 'vit-pytorch',
    packages = find_packages(exclude=['examples']),
-   version = '0.17.0',
+   version = '0.17.1',
    license='MIT',
    description = 'Vision Transformer (ViT) - Pytorch',
    author = 'Phil Wang',
6 changes: 3 additions & 3 deletions vit_pytorch/twins_svt.py
@@ -162,11 +162,11 @@ def __init__(self, dim, depth, heads = 8, dim_head = 64, mlp_mult = 4, local_pat
            Residual(PreNorm(dim, FeedForward(dim, mlp_mult, dropout = dropout)))
        ]))
    def forward(self, x):
-       for local_attn, ff, global_attn, ff in self.layers:
+       for local_attn, ff1, global_attn, ff2 in self.layers:
            x = local_attn(x)
-           x = ff(x)
+           x = ff1(x)
            x = global_attn(x)
-           x = ff(x)
+           x = ff2(x)
        return x

class TwinsSVT(nn.Module):
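The "small bug" being fixed is Python's tuple-unpacking behavior: when the same name appears twice in an unpacking target, the later binding wins, so in the old loop `ff` always referred to the second FeedForward module and the first was silently skipped. A minimal self-contained sketch of this (the toy callables below are hypothetical stand-ins, not the actual Twins-SVT modules):

```python
# Demonstrates the bug fixed in this commit: a duplicated name in
# tuple unpacking. The second binding of `ff` overwrites the first,
# so the first feedforward is never applied.

calls = []

def make(name):
    # Toy stand-in for a layer module: records that it ran, passes x through.
    def fn(x):
        calls.append(name)
        return x
    return fn

layers = [(make('local_attn'), make('ff_a'), make('global_attn'), make('ff_b'))]

# Buggy version of the loop: `ff` is bound twice, ending up as ff_b.
for local_attn, ff, global_attn, ff in layers:
    x = local_attn(0)
    x = ff(x)          # runs ff_b, NOT ff_a
    x = global_attn(x)
    x = ff(x)          # runs ff_b again

print(calls)  # ['local_attn', 'ff_b', 'global_attn', 'ff_b'] — ff_a never runs
```

Renaming the targets to `ff1` and `ff2`, as the commit does, gives each FeedForward its own binding and restores the intended attention → feedforward → attention → feedforward sequence.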
