Skip to content

Commit

Permalink
Merge pull request lucidrains#51 from umbertov/main
Browse files Browse the repository at this point in the history
Add class for distillation with efficient attention
  • Loading branch information
lucidrains authored Dec 25, 2020
2 parents e0007bd + 5a225c8 commit 4a6469e
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion vit_pytorch/distill.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import torch.nn.functional as F
from torch import nn
from vit_pytorch.vit_pytorch import ViT
from vit_pytorch.efficient import ViT as EfficientViT

from einops import rearrange, repeat

Expand Down Expand Up @@ -52,7 +53,12 @@ def __init__(
alpha = 0.5
):
super().__init__()
assert isinstance(student, DistillableViT), 'student must be a vision transformer'
assert (
isinstance(student, DistillableViT)
or
isinstance(student, DistillableEfficientViT)
) , 'student must be a vision transformer'

self.teacher = teacher
self.student = student

Expand Down Expand Up @@ -89,3 +95,33 @@ def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
distill_loss *= T ** 2

return loss * alpha + distill_loss * (1 - alpha)


class DistillableEfficientViT(EfficientViT):
def __init__(self, *args, **kwargs):
super(DistillableEfficientViT, self).__init__(*args, **kwargs)
self.dim = kwargs['dim']
self.num_classes = kwargs['num_classes']

def forward(self, img, distill_token, mask = None):
p = self.patch_size

x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)
x = self.patch_to_embedding(x)
b, n, _ = x.shape

cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
x = torch.cat((cls_tokens, x), dim = 1)
x += self.pos_embedding[:, :(n + 1)]

distill_tokens = repeat(distill_token, '() n d -> b n d', b = b)
x = torch.cat((x, distill_tokens), dim = 1)

x = self.transformer(x)

x, distill_tokens = x[:, :-1], x[:, -1]
x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]

x = self.to_latent(x)
return self.mlp_head(x), distill_tokens

0 comments on commit 4a6469e

Please sign in to comment.