diff --git a/README.md b/README.md
index 9624e48..45ee556 100644
--- a/README.md
+++ b/README.md
@@ -93,7 +93,8 @@ distiller = DistillWrapper(
     student = v,
     teacher = teacher,
     temperature = 3,           # temperature of distillation
-    alpha = 0.5                # trade between main loss and distillation loss
+    alpha = 0.5,               # trade between main loss and distillation loss
+    hard = False               # False: soft distillation (KL divergence), True: hard distillation (teacher argmax labels)
 )
 
 img = torch.randn(2, 3, 256, 256)
diff --git a/setup.py b/setup.py
index 9203fea..23cb6e9 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '0.14.2',
+  version = '0.14.4',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   author = 'Phil Wang',
diff --git a/vit_pytorch/distill.py b/vit_pytorch/distill.py
index d22447e..d6c9e13 100644
--- a/vit_pytorch/distill.py
+++ b/vit_pytorch/distill.py
@@ -104,7 +104,8 @@ def __init__(
         teacher,
         student,
         temperature = 1.,
-        alpha = 0.5
+        alpha = 0.5,
+        hard = False
     ):
         super().__init__()
         assert (isinstance(student, (DistillableViT, DistillableT2TViT, DistillableEfficientViT))) , 'student must be a vision transformer'
@@ -116,6 +117,7 @@ def __init__(
         num_classes = student.num_classes
         self.temperature = temperature
         self.alpha = alpha
+        self.hard = hard
 
         self.distillation_token = nn.Parameter(torch.randn(1, 1, dim))
 
@@ -137,11 +139,16 @@ def forward(self, img, labels, temperature = None, alpha = None, **kwargs):
 
         loss = F.cross_entropy(student_logits, labels)
 
-        distill_loss = F.kl_div(
-            F.log_softmax(distill_logits / T, dim = -1),
-            F.softmax(teacher_logits / T, dim = -1).detach(),
-        reduction = 'batchmean')
+        if not self.hard:
+            distill_loss = F.kl_div(
+                F.log_softmax(distill_logits / T, dim = -1),
+                F.softmax(teacher_logits / T, dim = -1).detach(),
+            reduction = 'batchmean')
+            distill_loss *= T ** 2
 
-        distill_loss *= T ** 2
+        else:
+            # hard distillation: cross entropy of the distillation logits against the teacher's argmax labels
+            teacher_labels = teacher_logits.argmax(dim = -1)
+            distill_loss = F.cross_entropy(distill_logits, teacher_labels)
 
         return loss * alpha + distill_loss * (1 - alpha)