make sure distillation still works
lucidrains committed Feb 22, 2021
1 parent 05edfff commit 6af7bbc
Showing 3 changed files with 10 additions and 12 deletions.
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.7.0',
+  version = '0.7.1',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
6 changes: 2 additions & 4 deletions vit_pytorch/distill.py
@@ -15,10 +15,8 @@ def exists(val):
 
 class DistillMixin:
     def forward(self, img, distill_token = None, mask = None):
-        p, distilling = self.patch_size, exists(distill_token)
-
-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
-        x = self.patch_to_embedding(x)
+        distilling = exists(distill_token)
+        x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
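
For reference, a minimal sketch of how this forward path gets exercised end to end, modeled on the distillation example in the project README (the DistillableViT / DistillWrapper names and the teacher / temperature / alpha arguments are taken from there, not from this diff):

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# any classifier that outputs logits can serve as the teacher
teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,   # softens the teacher logits for the distillation loss
    alpha = 0.5        # balance between hard-label loss and distillation loss
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)   # drives DistillMixin.forward with a distill token appended
loss.backward()
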
14 changes: 7 additions & 7 deletions vit_pytorch/efficient.py
@@ -1,6 +1,7 @@
 import torch
-from einops import rearrange, repeat
 from torch import nn
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
 
 class ViT(nn.Module):
     def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
@@ -10,10 +11,12 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
         num_patches = (image_size // patch_size) ** 2
         patch_dim = channels * patch_size ** 2
 
-        self.patch_size = patch_size
+        self.to_patch_embedding = nn.Sequential(
+            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
+            nn.Linear(patch_dim, dim),
+        )
 
         self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
-        self.patch_to_embedding = nn.Linear(patch_dim, dim)
         self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
         self.transformer = transformer
 
@@ -26,10 +29,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, transformer, pool = 'cls', channels = 3):
         )
 
     def forward(self, img):
-        p = self.patch_size
-
-        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
-        x = self.patch_to_embedding(x)
+        x = self.to_patch_embedding(img)
         b, n, _ = x.shape
 
         cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
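
For context, a minimal sketch of how the refactored module is driven, following the efficient-attention example in the README (the Linformer dependency and its arguments are assumptions taken from there; any transformer mapping (batch, seq, dim) to (batch, seq, dim) can be passed in):

import torch
from linformer import Linformer
from vit_pytorch.efficient import ViT

# 224 / 32 = 7, so 7 x 7 = 49 patch tokens plus one cls token
efficient_transformer = Linformer(
    dim = 512,
    seq_len = 49 + 1,
    depth = 12,
    heads = 8,
    k = 64
)

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 224, 224)
preds = v(img)   # (1, 1000); patching now runs through the to_patch_embedding Sequential above
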
