Report log likelihood for label smoothing
edunov authored and myleott committed Jan 22, 2018
1 parent c537860 commit dd31fa9
Showing 2 changed files with 21 additions and 3 deletions.
fairseq/criterions/label_smoothed_cross_entropy.py (3 additions, 0 deletions)

@@ -65,9 +65,11 @@ def forward(self, model, sample, reduce=True):
         lprobs = model.get_normalized_probs(net_output, log_probs=True)
         target = sample['target'].view(-1)
         loss = LabelSmoothedNLLLoss.apply(lprobs, target, self.eps, self.padding_idx, self.weights, reduce)
+        nll_loss = F.nll_loss(lprobs, target, size_average=False, ignore_index=self.padding_idx, reduce=reduce)
         sample_size = sample['target'].size(0) if self.args.sentence_avg else sample['ntokens']
         logging_output = {
             'loss': loss.data[0] if reduce else loss.data,
+            'nll_loss': nll_loss.data[0] if reduce else nll_loss.data,
             'sample_size': sample_size,
         }
         return loss, sample_size, logging_output
@@ -78,4 +80,5 @@ def aggregate_logging_outputs(logging_outputs):
         sample_size = sum(log.get('sample_size', 0) for log in logging_outputs)
         return {
             'loss': sum(log.get('loss', 0) for log in logging_outputs) / sample_size / math.log(2),
+            'nll_loss': sum(log.get('nll_loss', 0) for log in logging_outputs) / sample_size / math.log(2),
         }
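With label smoothing, the criterion mixes the gold token's log-probability with a penalty toward the uniform distribution, so the reported loss sits above the true negative log-likelihood, and a perplexity derived from it would be overstated. Logging nll_loss alongside loss lets training report perplexity from the actual log-likelihood. Note that aggregate_logging_outputs divides both sums by sample_size and math.log(2), converting per-token nats to bits so that perplexity can be computed as 2 ** loss. A minimal sketch of the gap (not from the commit; the smoothing formula below is one common formulation and is only assumed to match LabelSmoothedNLLLoss):

# Minimal sketch, not from the commit: label-smoothed loss vs. true NLL.
import math
import torch
import torch.nn.functional as F

eps, vocab, ntok = 0.1, 100, 6
logits = torch.zeros(ntok, vocab)
target = torch.arange(ntok)
logits[torch.arange(ntok), target] = 5.0      # model fairly confident in the gold tokens
lprobs = F.log_softmax(logits, dim=-1)

# True negative log-likelihood of the gold tokens, summed over tokens.
nll = F.nll_loss(lprobs, target, reduction='sum')

# Assumed smoothing formulation:
# (1 - eps) * NLL + eps * cross-entropy against the uniform distribution.
uniform_xent = -lprobs.sum(dim=-1).sum() / vocab
smoothed = (1.0 - eps) * nll + eps * uniform_xent

def to_ppl(x):
    # Per-token loss in bits (mirrors the math.log(2) division above), then 2 ** loss.
    return 2 ** (x.item() / ntok / math.log(2))

print('ppl from true NLL:      %.2f' % to_ppl(nll))       # ~1.67
print('ppl from smoothed loss: %.2f' % to_ppl(smoothed))  # ~2.73, overstated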
train.py (18 additions, 3 deletions)

@@ -150,6 +150,7 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
         sample_without_replacement=args.sample_without_replacement,
         sort_by_source_size=(epoch <= args.curriculum))
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     bsz_meter = AverageMeter()  # sentences per batch
     wpb_meter = AverageMeter()  # words per batch
     wps_meter = TimeMeter()     # words per second
@@ -164,6 +165,11 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
             ntokens = sum(s['ntokens'] for s in sample)
+
+            if 'nll_loss' in loss_dict:
+                nll_loss = loss_dict['nll_loss']
+                nll_loss_meter.update(nll_loss, ntokens)
+
             nsentences = sum(s['net_input']['src_tokens'].size(0) for s in sample)
             loss_meter.update(loss, nsentences if args.sentence_avg else ntokens)
             bsz_meter.update(nsentences)
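nll_loss_meter.update(nll_loss, ntokens) weights each batch by its token count, so .avg stays a per-token figure across batches. A sketch of the AverageMeter contract this relies on (an assumption for illustration; the real class is defined elsewhere in fairseq and may differ in detail):

# Assumed AverageMeter contract, for illustration only.
class AverageMeter(object):
    def __init__(self):
        self.sum, self.count = 0.0, 0

    def update(self, val, n=1):
        # Accumulate a weighted sum: `val` counted `n` times.
        self.sum += val * n
        self.count += n

    @property
    def avg(self):
        return self.sum / self.count if self.count > 0 else 0.0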
@@ -193,7 +199,9 @@ def train(args, epoch, batch_offset, trainer, dataset, max_positions):
 
         t.print(collections.OrderedDict([
             ('train loss', round(loss_meter.avg, 2)),
-            ('train ppl', get_perplexity(loss_meter.avg)),
+            ('train ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
             ('s/checkpoint', round(wps_meter.elapsed_time)),
             ('words/s', round(wps_meter.avg)),
             ('words/batch', round(wpb_meter.avg)),
@@ -242,16 +250,21 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
         descending=True,  # largest batch first to warm the caching allocator
     )
     loss_meter = AverageMeter()
+    nll_loss_meter = AverageMeter()
     extra_meters = collections.defaultdict(lambda: AverageMeter())
 
     prefix = 'valid on \'{}\' subset'.format(subset)
     with utils.build_progress_bar(args, itr, epoch, prefix) as t:
         for _, sample in data.skip_group_enumerator(t, args.num_gpus):
             loss_dict = trainer.valid_step(sample)
+            ntokens = sum(s['ntokens'] for s in sample)
             loss = loss_dict['loss']
             del loss_dict['loss']  # don't include in extra_meters or extra_postfix
 
-            ntokens = sum(s['ntokens'] for s in sample)
+            if 'nll_loss' in loss_dict:
+                nll_loss = loss_dict['nll_loss']
+                nll_loss_meter.update(nll_loss, ntokens)
+
             loss_meter.update(loss, ntokens)
 
             extra_postfix = []
@@ -265,7 +278,9 @@ def validate(args, epoch, trainer, dataset, max_positions, subset):
 
         t.print(collections.OrderedDict([
             ('valid loss', round(loss_meter.avg, 2)),
-            ('valid ppl', get_perplexity(loss_meter.avg)),
+            ('valid ppl', get_perplexity(nll_loss_meter.avg
+                                         if nll_loss_meter.count > 0
+                                         else loss_meter.avg)),
         ] + [
             (k, meter.avg)
             for k, meter in extra_meters.items()
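Both printouts now prefer the NLL meter and fall back to the smoothed training loss, so criteria that do not report nll_loss keep their previous perplexity behavior; guarding on nll_loss_meter.count > 0 rather than on the criterion type keeps train.py agnostic about which criterion is in use. get_perplexity itself is not part of the diff; a plausible sketch, assuming it receives a per-token loss in bits as produced by aggregate_logging_outputs:

# Hedged sketch of the helper called above; the real get_perplexity in
# train.py may differ.
import math

def get_perplexity(loss):
    try:
        return round(math.pow(2, loss), 2)
    except OverflowError:
        return float('inf')  # early in training, 2 ** loss can overflow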