more loss logs
airaria committed Nov 9, 2020
1 parent 55af0bc commit 51c4701
Showing 7 changed files with 181 additions and 133 deletions.
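In short: compute_loss and train_on_batch now return a dict of unweighted loss components (KD loss, intermediate-match losses, hard-label loss) alongside the weighted total, the training loops log each component through a write_loss helper instead of writing only the scalar total loss, and DistillationContext learns to save and restore the training mode of student models passed as a list or dict. The package version is bumped from 0.2.0.1 to 0.2.1.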
2 changes: 1 addition & 1 deletion setup.py
@@ -29,7 +29,7 @@

setup(
name="textbrewer",
version="0.2.0.1",
version="0.2.1",
author="ziqingyang",
author_email="[email protected]",
description="PyTorch-based knowledge distillation toolkit for natural language processing",
2 changes: 1 addition & 1 deletion src/textbrewer/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.2.0.1"
__version__ = "0.2.1"

from .distillers import BasicTrainer
from .distillers import BasicDistiller
222 changes: 119 additions & 103 deletions src/textbrewer/distiller_basic.py

Large diffs are not rendered by default.
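The distiller_basic.py diff is not shown here, but the call sites added in the other files (self.write_loss(total_loss, writer_step, losses_dict)) and the TensorBoard code removed from distiller_multitask.py suggest the new helper looks roughly like the sketch below. This is an inference, not the commit's actual code; the rank check and tb_writer attribute are assumed from the removed lines.

    # Hedged sketch of the write_loss method implied by its call sites; the real
    # implementation lives in the unrendered distiller_basic.py diff and may differ.
    def write_loss(self, total_loss, writer_step, losses_dict=None):
        if self.rank != 0:
            return  # only the main process writes TensorBoard scalars
        self.tb_writer.add_scalar('scalar/total_loss', total_loss.cpu().item(), writer_step)
        for name, loss in (losses_dict or {}).items():
            # e.g. 'unweighted_kd_loss', 'unweighted_hard_label_loss', ...
            self.tb_writer.add_scalar(f'scalar/{name}', loss.cpu().item(), writer_step)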

31 changes: 19 additions & 12 deletions src/textbrewer/distiller_general.py
@@ -76,17 +76,20 @@ def train_on_batch(self, batch, args):
results_T = post_adaptor(self.adaptor_T(teacher_batch,results_T))
results_S = post_adaptor(self.adaptor_S(student_batch,results_S))

total_loss = self.compute_loss(results_T, results_S)
total_loss, losses_dict = self.compute_loss(results_S, results_T)

return total_loss
return total_loss, losses_dict


def compute_loss(self,results_T,results_S):
def compute_loss(self,results_S,results_T):

losses_dict = dict()

total_loss = 0
if 'logits' in results_T and 'logits' in results_S:
logits_list_T = results_T['logits'] # list of tensor
logits_list_S = results_S['logits'] # list of tensor

total_kd_loss = 0
if 'logits_mask' in results_S:
masks_list_S = results_S['logits_mask']
logits_list_S = select_logits_with_mask(logits_list_S,masks_list_S) #(mask_sum, num_of_class)
@@ -102,16 +105,16 @@ def compute_loss(self,results_T,results_S):
temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
kd_loss = self.kd_loss(l_S, l_T, temperature) * self.d_config.kd_loss_weight
total_loss += kd_loss
total_kd_loss += self.kd_loss(l_S, l_T, temperature)
else:
for l_T,l_S in zip(logits_list_T,logits_list_S):
if self.d_config.temperature_scheduler is not None:
temperature = self.d_config.temperature_scheduler(l_S, l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
kd_loss = self.kd_loss(l_S, l_T, temperature) * self.d_config.kd_loss_weight
total_loss += kd_loss
total_kd_loss = self.kd_loss(l_S, l_T, temperature)
total_loss += total_kd_loss * self.d_config.kd_loss_weight
losses_dict['unweighted_kd_loss'] = total_kd_loss

inters_T = {feature: results_T.get(feature,[]) for feature in FEATURES}
inters_S = {feature: results_S.get(feature,[]) for feature in FEATURES}
@@ -137,7 +140,9 @@ def compute_loss(self,results_T,results_S):
if self.projs[ith]:
#inter_T = self.projs[ith](inter_T)
inter_S = self.projs[ith](inter_S)
total_loss += match_loss(inter_S, inter_T, mask=inputs_mask_S) * match_weight
intermediate_loss = match_loss(inter_S, inter_T, mask=inputs_mask_S)
total_loss += intermediate_loss * match_weight
losses_dict[f'unweighted_{feature}_{loss_type}_{layer_S}_{layer_T}'] = intermediate_loss

if self.has_custom_matches:
for hook_T, hook_S, match_weight, match_loss, proj_func in \
@@ -151,11 +156,13 @@ def compute_loss(self,results_T,results_S):
self.custom_matches_cache['hook_outputs_S'] = []

if 'losses' in results_S:
total_hl_loss = 0
for loss in results_S['losses']:
# in case of multi-GPU
total_loss += loss.mean() * self.d_config.hard_label_weight
return total_loss

total_hl_loss += loss.mean()
total_loss += total_hl_loss * self.d_config.hard_label_weight
losses_dict['unweighted_hard_label_loss'] = total_hl_loss
return total_loss, losses_dict

def add_match(self,match: CustomMatch):
if type(match.module_T) is str or type(match.module_S) is str:
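With this change, compute_loss returns the weighted total together with the unweighted components, so a caller can inspect or log each term separately. A minimal illustration of the new contract (hypothetical usage inside a distiller; the per-match dict keys are built from the feature name, loss type, and layer indices of the configured matches):

    # Illustration only: consuming the new (total_loss, losses_dict) return value.
    total_loss, losses_dict = self.compute_loss(results_S=results_S, results_T=results_T)
    for name, value in losses_dict.items():
        # keys such as 'unweighted_kd_loss' and 'unweighted_hard_label_loss'
        print(name, value.item())
    total_loss.backward()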
19 changes: 10 additions & 9 deletions src/textbrewer/distiller_multitask.py
@@ -14,7 +14,7 @@ class MultiTaskDistiller(GeneralDistiller):
adaptor_S (dict): dict of student adaptors: {task1:adpt1, task2:adpt2, .... }. Keys are tasknames.
"""

def __init__(self, train_config,
distill_config,
model_T,
@@ -76,7 +76,7 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
tasknames = tuple(dataloaders.keys())
sampling_weights = None


global_step = 0
writer_step = 0
optimizer.zero_grad()
@@ -90,17 +90,18 @@ def train(self, optimizer, dataloaders, num_steps, scheduler_class=None, schedul
if batch_postprocessors is not None:
batch = batch_postprocessors[taskname](batch)
batch_taskname = (batch, taskname)
total_loss = self.train_on_batch(batch_taskname, args)
total_loss, losses_dict = self.train_on_batch(batch_taskname, args)

self.write_loss(total_loss,writer_step,losses_dict)
writer_step += 1

total_loss /= self.t_config.gradient_accumulation_steps
if self.t_config.fp16:
with amp.scale_loss(total_loss,optimizer) as scaled_loss:
scaled_loss.backward()
else:
total_loss.backward()
if self.rank == 0:
scalar_total_loss = total_loss.cpu().item() * self.t_config.gradient_accumulation_steps
self.tb_writer.add_scalar('scalar/total_loss', scalar_total_loss, writer_step)
writer_step += 1

if max_grad_norm > 0:
if self.t_config.fp16:
torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), max_grad_norm)
@@ -135,6 +136,6 @@ def train_on_batch(self, batch_taskname, args) -> torch.Tensor:
results_T = post_adaptor(adaptor_T(teacher_batch,results_T))
results_S = post_adaptor(adaptor_S(student_batch,results_S))

total_loss = self.compute_loss(results_T, results_S)
total_loss, losses_dict = self.compute_loss(results_S=results_S, results_T=results_T)

return total_loss
return total_loss, losses_dict
16 changes: 12 additions & 4 deletions src/textbrewer/distiller_multiteacher.py
@@ -50,6 +50,8 @@ def train_on_batch(self, batch, args):
logits_list_T = [results_t['logits'] for results_t in results_T] # list of tensor
logits_list_S = results_S['logits'] # list of tensor
total_loss = 0
losses_dict = dict()
total_kd_loss = 0

if 'logits_mask' in results_S:
masks_list_S = results_S['logits_mask']
@@ -68,21 +70,27 @@ def train_on_batch(self, batch, args):
temperature = self.d_config.temperature_scheduler(l_S, mean_l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
total_loss += self.kd_loss(l_S, mean_l_T, temperature) * self.d_config.kd_loss_weight
total_kd_loss += self.kd_loss(l_S, mean_l_T, temperature)
else:
for l_T, l_S in zip(zip(*logits_list_T),logits_list_S):
mean_l_T = sum(l_T)/len(l_T)
if self.d_config.temperature_scheduler is not None:
temperature = self.d_config.temperature_scheduler(l_S, mean_l_T, self.d_config.temperature)
else:
temperature = self.d_config.temperature
total_loss += self.kd_loss(l_S, mean_l_T, temperature) * self.d_config.kd_loss_weight
total_kd_loss += self.kd_loss(l_S, mean_l_T, temperature)
total_loss += total_kd_loss * self.d_config.kd_loss_weight
losses_dict['unweighted_kd_loss'] = total_kd_loss

if 'losses' in results_S:
total_hl_loss = 0
for loss in results_S['losses']:
# in case of multi-GPU
total_loss += loss.mean() * self.d_config.hard_label_weight
return total_loss
total_hl_loss += loss.mean()
total_loss += total_hl_loss * self.d_config.hard_label_weight
losses_dict['unweighted_hard_label_loss'] = total_hl_loss

return total_loss, losses_dict

def cache_logits(self, batch, args, batch_postprocessor):
if batch_postprocessor is not None:
22 changes: 19 additions & 3 deletions src/textbrewer/distiller_utils.py
@@ -70,8 +70,17 @@ def __enter__(self):
self.model_T_is_training = self.model_T.training
self.model_T.eval()

self.model_S_is_training = self.model_S.training
self.model_S.train()
if isinstance(self.model_S,(list,tuple)):
self.model_S_is_training = [model_s.training for model_s in self.model_S]
for model_s in self.model_S:
model_s.eval()
elif isinstance(self.model_S,dict):
self.model_S_is_training = {name:model.training for name,model in self.model_S.items()}
for name in self.model_S:
self.model_S[name].eval()
else:
self.model_S_is_training = self.model_S.training
self.model_S.train()

def __exit__(self, exc_type, exc_val, exc_tb):
#Restore model status
@@ -84,7 +93,14 @@ def __exit__(self, exc_type, exc_val, exc_tb):
else:
self.model_T.train(self.model_T_is_training)

self.model_S.train(self.model_S_is_training)
if isinstance(self.model_S,(list,tuple)):
for i in range(len(self.model_S_is_training)):
self.model_S[i].train(self.model_S_is_training[i])
elif isinstance(self.model_S,dict):
for name,is_training in self.model_S_is_training.items():
self.model_S[name].train(is_training)
else:
self.model_S.train(self.model_S_is_training)


class AbstractDistiller(DistillationContext):
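The context manager now records and restores the training mode of every student model when model_S is a list/tuple or a dict, not just a single module. A toy illustration of that save/restore pattern, independent of the distiller classes (the names here are hypothetical):

    import torch.nn as nn

    # Toy sketch of the bookkeeping added above: snapshot each model's mode,
    # switch modes for the managed block, then restore the original modes.
    models = {"taskA": nn.Linear(4, 2), "taskB": nn.Linear(4, 2)}
    models["taskA"].eval()                                     # start from mixed modes

    saved = {name: m.training for name, m in models.items()}   # __enter__-style snapshot
    for m in models.values():
        m.eval()                                                # mode used inside the block

    for name, was_training in saved.items():                    # __exit__-style restore
        models[name].train(was_training)

    assert models["taskA"].training is False
    assert models["taskB"].training is True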
