
Pytorch version of Vision Transformer (ViT) with pretrained models. This is part of CASL (https://casl-project.github.io/) and ASYML project.


GQI7FS6/vision-transformer-pytorch


Vision Transformer - Pytorch

PyTorch implementation of Vision Transformer. Pretrained PyTorch weights, converted from the original jax/flax weights, are provided. This is a project of the ASYML family and CASL.

Introduction

Figure 1 from paper

PyTorch implementation of the paper An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale. We provide pretrained PyTorch weights converted from the pretrained jax/flax models, along with fine-tuning and evaluation scripts. The results are close to those of the original implementation.
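
As a rough illustration of the "16x16 words" idea, the sketch below computes how many tokens the encoder sees for a given image and patch size, assuming (as in the paper) non-overlapping square patches plus one class token. The function name is hypothetical and not part of this repo.

```python
def vit_sequence_length(image_size: int, patch_size: int) -> int:
    """Tokens seen by the encoder: one per patch plus one class token.

    Assumes square images split into non-overlapping square patches,
    as described in the ViT paper.
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    return patches_per_side * patches_per_side + 1


# 384x384 input with 16x16 patches: 24 * 24 patches + 1 class token
print(vit_sequence_length(384, 16))  # 577
# 384x384 input with 32x32 patches: 12 * 12 patches + 1 class token
print(vit_sequence_length(384, 32))  # 145
```

The larger patch size gives a much shorter sequence, which is one reason the _32 variants are cheaper to run than the _16 variants.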

Installation

Create environment:

conda create --name vit --file requirements.txt
conda activate vit

Available Models

We provide PyTorch model weights converted from the original jax/flax weights. Download them and put the files under 'weights/pytorch' to use them.

Alternatively, download the original jax/flax weights and put the files under 'weights/jax'. They are converted to PyTorch format on the fly when loaded.
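
The two weight formats can be told apart by file extension ('.pth' for PyTorch, '.npy' for jax/flax). The helper below is a minimal sketch of that check; `checkpoint_format` is a hypothetical name for illustration, not a function from this repo, and the repo's own loading logic may differ.

```python
from pathlib import Path


def checkpoint_format(checkpoint_path: str) -> str:
    """Guess the weight format from the file extension (illustrative only)."""
    suffix = Path(checkpoint_path).suffix.lower()
    if suffix == ".pth":
        return "pytorch"  # already converted, load directly
    if suffix == ".npy":
        return "jax"      # original weights, need conversion before use
    raise ValueError(f"Unrecognised checkpoint extension: {suffix!r}")


print(checkpoint_format("weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth"))
print(checkpoint_format("weights/jax/imagenet21k+imagenet2012_ViT-B_16.npy"))
```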

Datasets

Currently three datasets are supported: ImageNet2012, CIFAR10, and CIFAR100. To evaluate or fine-tune on these datasets, download the datasets and put them in 'data/dataset_name'.

More datasets will be supported.

Fine-Tune/Train

python src/train.py --exp-name ft --n-gpu 4 --tensorboard  --model-arch b16 --checkpoint-path weights/pytorch/imagenet21k+imagenet2012_ViT-B_16.pth --image-size 384 --batch-size 32 --data-dir data/ --dataset CIFAR10 --num-classes 10 --train-steps 10000 --lr 0.03 --wd 0.0
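
To get a feel for how long this run is, the step count can be converted to approximate epochs. The sketch below assumes that --batch-size 32 is the total batch size across all GPUs (the README does not say whether it is per GPU) and uses the standard CIFAR10 training split of 50,000 images; `approx_epochs` is a hypothetical helper, not part of this repo.

```python
def approx_epochs(train_steps: int, batch_size: int, num_train_images: int) -> float:
    """Approximate number of passes over the training set."""
    return train_steps * batch_size / num_train_images


# Values from the command above; 50,000 is the CIFAR10 training set size.
# Assumption: batch size 32 is the total across GPUs, not per GPU.
epochs = approx_epochs(train_steps=10000, batch_size=32, num_train_images=50000)
print(f"{epochs:.1f} epochs")  # 6.4 epochs
```

If the batch size turns out to be per GPU, multiply the result by the number of GPUs.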

Evaluation

Make sure you have downloaded the pretrained weights in either '.npy' or '.pth' format.

python src/eval.py --model-arch b16 --checkpoint-path weights/jax/imagenet21k+imagenet2012_ViT-B_16.npy --image-size 384 --batch-size 128 --data-dir data/ImageNet --dataset ImageNet --num-classes 1000

Results and Models

Pretrained Results on ImageNet2012

upstream     model     dataset       orig. jax acc (%)  pytorch acc (%)  model link
imagenet21k  ViT-B_16  imagenet2012  84.62              83.90            checkpoint
imagenet21k  ViT-B_32  imagenet2012  81.79              81.14            checkpoint
imagenet21k  ViT-L_16  imagenet2012  85.07              84.94            checkpoint
imagenet21k  ViT-L_32  imagenet2012  82.01              81.03            checkpoint
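
To make "similar results" concrete, the short script below computes the gap between the original jax accuracy and the PyTorch accuracy for each row of the table above. The numbers are copied from the table; the variable and function names are illustrative only.

```python
# (model, original jax accuracy, pytorch accuracy), copied from the table above
results = [
    ("ViT-B_16", 84.62, 83.90),
    ("ViT-B_32", 81.79, 81.14),
    ("ViT-L_16", 85.07, 84.94),
    ("ViT-L_32", 82.01, 81.03),
]


def accuracy_gaps(rows):
    """Return {model: jax accuracy minus pytorch accuracy}, in percentage points."""
    return {model: round(jax_acc - torch_acc, 2) for model, jax_acc, torch_acc in rows}


gaps = accuracy_gaps(results)
for model, gap in gaps.items():
    print(f"{model}: {gap:.2f} points below the jax result")
```

On these four models the PyTorch accuracy is between 0.13 and 0.98 percentage points below the original.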

Fine-Tune Results on CIFAR10/100

Due to limited GPU resources, the fine-tuning results were obtained with a batch size of 32, which may slightly reduce accuracy.

upstream     model     dataset   orig. jax acc (%)  pytorch acc (%)
imagenet21k  ViT-B_16  CIFAR10   98.92              98.90
imagenet21k  ViT-B_16  CIFAR100  92.26              91.65

TODO

  • Colab
  • Integrate into Texar

Acknowledgements

  1. https://github.com/google-research/vision_transformer
  2. https://github.com/lucidrains/vit-pytorch
  3. https://github.com/kamalkraj/Vision-Transformer

Contributing

Issues and Pull Requests to improve this repo are welcome. Please follow the contribution guide.

License

Apache License 2.0

