test_distillation.py
import torch

from haystack.nodes import FARMReader


def test_distillation():
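    """Distilling a tiny student model from a small teacher should update the student's trainable weights."""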
    student = FARMReader(model_name_or_path="prajjwal1/bert-tiny")
    teacher = FARMReader(model_name_or_path="prajjwal1/bert-small")

    # create a checkpoint of the student's weights before distillation
    student_weights = []
    for name, weight in student.inferencer.model.named_parameters():
        if "weight" in name and weight.requires_grad:
            student_weights.append(torch.clone(weight))

    assert len(student_weights) == 22

    student_weights.pop(-2)  # the pooler is not updated due to the different attention head, so exclude it from the comparison

    student.distil_from(teacher, data_dir="samples/squad", train_filename="tiny.json")
    # create a new checkpoint of the student's weights after distillation
    new_student_weights = []
    for name, weight in student.inferencer.model.named_parameters():
        if "weight" in name and weight.requires_grad:
            new_student_weights.append(torch.clone(weight))

    assert len(new_student_weights) == 22

    new_student_weights.pop(-2)  # the pooler is not updated due to the different attention head, so exclude it from the comparison

    # check that every remaining weight tensor changed during distillation
    assert not any(torch.equal(old_weight, new_weight) for old_weight, new_weight in zip(student_weights, new_student_weights))
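

# Note: a minimal way to run this check locally, assuming pytest is installed and the
# samples/squad/tiny.json fixture from the repository is available relative to this file:
#   pytest test_distillation.py -k test_distillation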