Merge pull request lucidrains#128 from stevenwalton/main
Adding Compact Convolutional Transformers (CCT)
lucidrains authored Jul 2, 2021
2 parents 64a2ef6 + 2ece333 commit 121353c
Showing 2 changed files with 407 additions and 0 deletions.
README.md: 68 additions, 0 deletions
@@ -62,6 +62,7 @@ Dropout rate.
Embedding dropout rate.
- `pool`: string, either `cls` token pooling or `mean` pooling


## Distillation

<img src="./images/distill.png" width="300px"></img>
@@ -118,6 +119,7 @@ v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>
```


## Deep ViT

This <a href="https://arxiv.org/abs/2103.11886">paper</a> notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the <a href="https://github.com/lucidrains/x-transformers#talking-heads-attention">Talking Heads</a> paper from NLP.
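
At its core, Re-attention mixes the per-head attention maps with a small learned head-to-head matrix after the softmax. Below is a minimal sketch of that mixing step, with illustrative shapes and a placeholder matrix rather than the repository's implementation:

```python
import torch

# Sketch of Re-attention's head mixing (illustrative shapes, not the repo's code):
# a (heads x heads) matrix, learned in practice, recombines post-softmax attention maps.
batch, heads, n = 1, 8, 65
attn = torch.randn(batch, heads, n, n).softmax(dim=-1)  # per-head attention maps
mix  = torch.randn(heads, heads)                        # learned head-mixing matrix in the paper
reattn = torch.einsum('b h i j, h g -> b g i j', attn, mix)
# the paper additionally re-normalizes the mixed maps before applying them to the values
```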
@@ -201,6 +203,61 @@ img = torch.randn(1, 3, 224, 224)
preds = v(img) # (1, 1000)
```

## CCT

<img src="https://raw.githubusercontent.com/SHI-Labs/Compact-Transformers/main/images/model_sym.png" width="400px"></img>

<a href="https://arxiv.org/abs/2104.05704">CCT</a> proposes compact transformers that replace image patching with a convolutional tokenizer and replace the class token with sequence pooling. This lets CCT reach high accuracy with a low parameter count.
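
Sequence pooling amounts to an attention-weighted average over the output token sequence. A minimal sketch of the idea, with illustrative names and shapes rather than the library's internals:

```python
import torch
from torch import nn

# Sequence pooling (SeqPool) sketch: score each token, softmax over the sequence,
# then take the weighted sum as the single vector fed to the classifier head.
tokens  = torch.randn(1, 196, 384)          # (batch, seq_len, dim), illustrative shapes
score   = nn.Linear(384, 1)                 # learned per-token importance score
weights = score(tokens).softmax(dim=1)      # (batch, seq_len, 1)
pooled  = (weights * tokens).sum(dim=1)     # (batch, dim)
```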

You can use CCT in two ways. The first is to specify every hyperparameter yourself:
```python
import torch
from vit_pytorch.cct import CCT

model = CCT(
    img_size=224,
    embedding_dim=384,
    n_conv_layers=2,          # depth of the convolutional tokenizer
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_layers=14,            # transformer encoder depth
    num_heads=6,
    mlp_ratio=3.,             # feedforward hidden dim = mlp_ratio * embedding_dim
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```

Alternatively, you can use one of several pre-defined models `[2,4,6,7,8,14,16]`, which fix the number of layers, number of attention heads, MLP ratio, and embedding dimension for you.

```python
import torch
from vit_pytorch.cct import cct_14

model = cct_14(
    img_size=224,
    n_conv_layers=1,
    kernel_size=7,
    stride=2,
    padding=3,
    pooling_kernel_size=3,
    pooling_stride=2,
    pooling_padding=1,
    num_classes=1000,
    positional_embedding='learnable', # ['sine', 'learnable', 'none']
)
```
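
As with the other models in this README, the resulting module maps a batch of images to class logits. A quick check of the forward pass (the output shape assumes `num_classes=1000` as configured above):

```python
img = torch.randn(1, 3, 224, 224)

preds = model(img) # (1, 1000)
```
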
The <a href="https://github.com/SHI-Labs/Compact-Transformers">official repository</a> includes links to pretrained model checkpoints.


## Cross ViT

<img src="./images/cross_vit.png" width="400px"></img>
@@ -680,6 +737,17 @@ Coming from computer vision and new to transformers? Here are some resources that


## Citations
```bibtex
@article{hassani2021escaping,
    title   = {Escaping the Big Data Paradigm with Compact Transformers},
    author  = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
    year    = {2021},
    url     = {https://arxiv.org/abs/2104.05704},
    eprint  = {2104.05704},
    archiveprefix = {arXiv},
    primaryclass = {cs.CV}
}
```

```bibtex
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archiveprefix = {arXiv},
    primaryclass = {cs.CV}
}
```
