
Commit b1f1044
offer hard distillation as well
lucidrains committed Apr 1, 2021
1 parent deb9620 commit b1f1044
Showing 3 changed files with 15 additions and 8 deletions.
README.md: 3 changes (2 additions, 1 deletion)

@@ -93,7 +93,8 @@ distiller = DistillWrapper(
     student = v,
     teacher = teacher,
     temperature = 3,            # temperature of distillation
-    alpha = 0.5                 # trade between main loss and distillation loss
+    alpha = 0.5,                # trade between main loss and distillation loss
+    hard = False                # whether to use soft or hard distillation
 )
 
 img = torch.randn(2, 3, 256, 256)
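
For context, a minimal end-to-end sketch of driving the updated wrapper with hard = True. The resnet50 teacher and the DistillableViT hyperparameters are assumptions carried over from the rest of the README example, not part of this diff:

import torch
from torchvision.models import resnet50
from vit_pytorch.distill import DistillableViT, DistillWrapper

# assumed teacher / student setup, mirroring the rest of the README example
teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,            # unused when hard = True
    alpha = 0.5,                # trade between main loss and distillation loss
    hard = True                 # distill against the teacher's argmax labels
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)   # alpha * CE(student, labels) + (1 - alpha) * distill loss
loss.backward()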
setup.py: 2 changes (1 addition, 1 deletion)

@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.14.2',
+  version = '0.14.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
vit_pytorch/distill.py: 18 changes (12 additions, 6 deletions)

@@ -104,7 +104,8 @@ def __init__(
         teacher,
         student,
         temperature = 1.,
-        alpha = 0.5
+        alpha = 0.5,
+        hard = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -116,6 +117,7 @@ def __init__(
         num_classes = student.num_classes
         self.temperature = temperature
         self.alpha = alpha
+        self.hard = hard
 
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
@@ -137,11 +139,15 @@ def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
 
         loss = F.cross_entropy(student_logits, labels)
 
-        distill_loss = F.kl_div(
-            F.log_softmax(distill_logits / T, dim = -1),
-            F.softmax(teacher_logits / T, dim = -1).detach(),
-        reduction = 'batchmean')
+        if not self.hard:
+            distill_loss = F.kl_div(
+                F.log_softmax(distill_logits / T, dim = -1),
+                F.softmax(teacher_logits / T, dim = -1).detach(),
+            reduction = 'batchmean')
+            distill_loss *= T ** 2
 
-        distill_loss *= T ** 2
+        else:
+            teacher_labels = teacher_logits.argmax(dim = -1)
+            distill_loss = F.cross_entropy(student_logits, teacher_labels)
 
         return loss * alpha + distill_loss * (1 - alpha)
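
For reference, a standalone sketch of the two objectives this branch switches between, using dummy logits in place of the ones the wrapper computes internally (student_logits and distill_logits come from the student, teacher_logits from the teacher):

import torch
import torch.nn.functional as F

# dummy logits standing in for what DistillWrapper computes internally
student_logits = torch.randn(2, 1000)
distill_logits = torch.randn(2, 1000)   # from the distillation token
teacher_logits = torch.randn(2, 1000)
T = 3.0

# soft distillation (hard = False): KL divergence between temperature-scaled
# distributions, scaled by T ** 2 to keep gradient magnitudes comparable
soft_loss = F.kl_div(
    F.log_softmax(distill_logits / T, dim = -1),
    F.softmax(teacher_logits / T, dim = -1).detach(),
    reduction = 'batchmean') * T ** 2

# hard distillation (hard = True): cross entropy against the teacher's argmax labels
hard_loss = F.cross_entropy(student_logits, teacher_logits.argmax(dim = -1))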
