Commit ee8088b: first commit
lucidrains committed Oct 4, 2020 (1 parent: 070469d)
Showing 4 changed files with 151 additions and 1 deletion.
README.md: 27 additions, 1 deletion
@@ -1,7 +1,33 @@
-## Vision Transformer - Pytorch (wip)
+## Vision Transformer - Pytorch

A Pytorch implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder. There's really not much to code here, but we may as well lay it all out so we can expedite the attention revolution and get everyone on the same page.

## Install

```bash
$ pip install vit-pytorch
```

## Usage

```python
import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,     # input image height and width
    patch_size = 32,      # side length of each square patch
    num_classes = 1000,   # number of classification outputs
    dim = 1024,           # embedding dimension of each token
    depth = 6,            # number of transformer blocks
    heads = 8,            # number of attention heads
    mlp_dim = 2048        # hidden dimension of the feedforward sublayers
)

img = torch.randn(1, 3, 256, 256)
preds = v(img) # (1, 1000)
```
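For context, here is a minimal sketch of a single training step. The optimizer choice, learning rate, and dummy data below are illustrative assumptions, not part of the library:

```python
import torch
import torch.nn.functional as F
from vit_pytorch import ViT

v = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
        dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

opt = torch.optim.Adam(v.parameters(), lr = 3e-4)  # illustrative hyperparameter

imgs = torch.randn(8, 3, 256, 256)       # dummy batch of images
labels = torch.randint(0, 1000, (8,))    # dummy class targets

loss = F.cross_entropy(v(imgs), labels)  # standard classification loss
loss.backward()
opt.step()
opt.zero_grad()
```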

## Citations

setup.py: 28 additions
@@ -0,0 +1,28 @@
from setuptools import setup, find_packages

setup(
    name = 'vit-pytorch',
    packages = find_packages(),
    version = '0.0.1',
    license = 'MIT',
    description = 'Vision Transformer (ViT) - Pytorch',
    author = 'Phil Wang',
    author_email = '[email protected]',
    url = 'https://github.com/lucidrains/vit-pytorch',
    keywords = [
        'artificial intelligence',
        'attention mechanism',
        'image recognition'
    ],
    install_requires = [
        'torch>=1.6',
        'einops>=0.3'
    ],
    classifiers = [
        'Development Status :: 4 - Beta',
        'Intended Audience :: Developers',
        'Topic :: Scientific/Engineering :: Artificial Intelligence',
        'License :: OSI Approved :: MIT License',
        'Programming Language :: Python :: 3.6',
    ],
)
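With the repository cloned, the package can also be installed in editable mode for development with `pip install -e .`.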
vit_pytorch/__init__.py: 1 addition
@@ -0,0 +1 @@
from vit_pytorch.vit_pytorch import ViT
vit_pytorch/vit_pytorch.py: 95 additions
@@ -0,0 +1,95 @@
import torch
import torch.nn.functional as F  # unused in this module
from torch import nn

from einops import rearrange

# wraps a module with a residual (skip) connection
class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x):
        return self.fn(x) + x

# applies LayerNorm to the input before calling the wrapped module
class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x):
        return self.fn(self.norm(x))
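Composed, these two wrappers produce the standard pre-norm transformer sublayer, out = fn(LayerNorm(x)) + x. A quick equivalence check, assuming the classes above are in scope (the linear sublayer here is an arbitrary stand-in):

```python
import torch
from torch import nn

dim = 8
fn = nn.Linear(dim, dim)              # any shape-preserving sublayer
block = Residual(PreNorm(dim, fn))    # the wrapping used in Transformer below

x = torch.randn(2, 4, dim)
reference = fn(block.fn.norm(x)) + x  # fn(LayerNorm(x)) + x, computed by hand
assert torch.allclose(block(x), reference)
```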

# two-layer MLP with GELU activation, applied position-wise to each token
class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, dim)
        )
    def forward(self, x):
        return self.net(x)

# multi-head self-attention; note the scale here uses the full embedding
# dimension, not the per-head dimension as in the original attention paper
class Attention(nn.Module):
    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim)
    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        # project to queries, keys, values in one pass, then split into heads
        qkv = self.to_qkv(x)
        q, k, v = rearrange(qkv, 'b n (qkv h d) -> qkv b h n d', qkv = 3, h = h)

        dots = torch.einsum('bhid,bhjd->bhij', q, k) * self.scale  # scaled pairwise dot products
        attn = dots.softmax(dim=-1)

        out = torch.einsum('bhij,bhjd->bhid', attn, v)  # attention-weighted sum of values
        out = rearrange(out, 'b h n d -> b n (h d)')    # merge heads back together
        out = self.to_out(out)
        return out
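A quick shape walkthrough, assuming the Attention class above is in scope (the sizes are illustrative):

```python
import torch

attn = Attention(dim = 64, heads = 8)
x = torch.randn(2, 17, 64)   # (batch, tokens, dim): e.g. 16 patches + 1 cls token
out = attn(x)                # qkv is (2, 17, 192), split into 8 heads of size 8 each
assert out.shape == (2, 17, 64)
```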

# a stack of depth pre-norm attention and feedforward blocks
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, mlp_dim):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.extend([
                Residual(PreNorm(dim, Attention(dim, heads = heads))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim)))
            ])
        self.net = nn.Sequential(*layers)
    def forward(self, x):
        return self.net(x)

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, channels = 3):
        super().__init__()
        assert image_size % patch_size == 0, 'image dimensions must be divisible by the patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size

        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.transformer = Transformer(dim, depth, heads, mlp_dim)

        self.mlp_head = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, num_classes)
        )

    def forward(self, img):
        p = self.patch_size

        # split the image into non-overlapping p x p patches and flatten each one
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
        x = self.patch_to_embedding(x)

        # prepend a learned cls token, expanded across the batch so batch sizes > 1 work
        cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x += self.pos_embedding
        x = self.transformer(x)

        # classify from the final representation of the cls token
        return self.mlp_head(x[:, 0])
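An end-to-end shape check, assuming the ViT class above is in scope (it mirrors the README usage, with a batch of two to exercise the cls token expansion):

```python
import torch

v = ViT(image_size = 256, patch_size = 32, num_classes = 1000,
        dim = 1024, depth = 6, heads = 8, mlp_dim = 2048)

img = torch.randn(2, 3, 256, 256)
# 256 / 32 = 8 patches per side -> 64 patches, plus 1 cls token = 65 tokens of dim 1024
preds = v(img)
assert preds.shape == (2, 1000)
```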
