From 197c132e9a1e3a1a7318a5cb99e065af6adffc6f Mon Sep 17 00:00:00 2001
From: mohammad
Date: Fri, 5 Jun 2020 13:57:21 -0700
Subject: [PATCH] addressed jareds comments

---
 megatron/arguments.py        |  3 +++
 megatron/model/bert_model.py |  8 +++++++-
 megatron/model/gpt2_model.py |  7 ++++++-
 pretrain_bert.py             | 14 +++++---------
 pretrain_gpt2.py             |  9 ++-------
 5 files changed, 23 insertions(+), 18 deletions(-)

diff --git a/megatron/arguments.py b/megatron/arguments.py
index 67a46ee8b4..2333b0cbac 100644
--- a/megatron/arguments.py
+++ b/megatron/arguments.py
@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.num_unique_layers < args.num_layers:
         assert args.DDP_impl == 'local', \
             'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 is only supported in fp16 mode.'
 
     _print_args(args)
     return args
diff --git a/megatron/model/bert_model.py b/megatron/model/bert_model.py
index a32f4307f9..14f8bea6cd 100644
--- a/megatron/model/bert_model.py
+++ b/megatron/model/bert_model.py
@@ -115,6 +115,7 @@ def __init__(self, num_tokentypes=2, add_binary_head=True,
         super(BertModel, self).__init__()
         args = get_args()
 
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
@@ -170,7 +171,12 @@ def forward(self, input_ids, attention_mask,
 
         if lm_labels is None:
             return lm_logits, binary_logits
         else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
             return lm_loss, binary_logits
 
diff --git a/megatron/model/gpt2_model.py b/megatron/model/gpt2_model.py
index b5e0a59ba2..b0d275f1cb 100644
--- a/megatron/model/gpt2_model.py
+++ b/megatron/model/gpt2_model.py
@@ -40,6 +40,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True):
         args = get_args()
 
         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,
@@ -79,7 +80,11 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,
 
         if labels is None:
             return output
         else:
-            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
             return loss
 
diff --git a/pretrain_bert.py b/pretrain_bert.py
index e4153ba8d2..0d38c139f4 100644
--- a/pretrain_bert.py
+++ b/pretrain_bert.py
@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model. lm_labels
-    if args.fp16_lm_cross_entropy:
-        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
-                                     lm_labels=lm_labels)
-    else:
-        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
-        lm_loss_ = mpu.vocab_parallel_cross_entropy(
-            lm_logits.contiguous().float(), lm_labels.contiguous())
+    lm_loss_, sop_logits = model(tokens, padding_mask,
+                                 tokentype_ids=types,
+                                 lm_labels=lm_labels)
 
-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)
 
     lm_loss = torch.sum(
diff --git a/pretrain_gpt2.py b/pretrain_gpt2.py
index 396bbb7a90..6adeb1d2ab 100644
--- a/pretrain_gpt2.py
+++ b/pretrain_gpt2.py
@@ -82,13 +82,8 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        losses = model(tokens, position_ids, attention_mask, labels=labels)
-    else:
-        output = model(tokens, position_ids, attention_mask)
-        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                                  labels)
-
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
+
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
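
Reviewer sketch (not part of the patch): a minimal standalone view of the precision
branching this change moves into BertModel.forward and GPT2Model.forward. The helper
name lm_cross_entropy is made up for illustration, torch.nn.functional.cross_entropy
stands in for mpu.vocab_parallel_cross_entropy (which needs initialized model-parallel
groups), and the tensor shapes are illustrative only.

import torch
import torch.nn.functional as F


def lm_cross_entropy(lm_logits, lm_labels, fp16_lm_cross_entropy=False):
    """Per-token LM loss, mirroring the in-model loss computation added here."""
    if fp16_lm_cross_entropy:
        # --fp16-lm-cross-entropy path: compute the loss on half-precision logits
        # (assumes the logits come from an fp16 model running on GPU).
        assert lm_logits.dtype == torch.half
        logits = lm_logits
    else:
        # Default path: upcast to fp32 before the softmax / cross entropy.
        logits = lm_logits.float()
    return F.cross_entropy(logits.view(-1, logits.size(-1)),
                           lm_labels.view(-1),
                           reduction='none').view_as(lm_labels)


if __name__ == '__main__':
    logits = torch.randn(2, 4, 8)           # [batch, seq, vocab]
    labels = torch.randint(0, 8, (2, 4))    # [batch, seq]
    # Per-token losses, to be masked and averaged as in forward_step().
    print(lm_cross_entropy(logits, labels).shape)   # torch.Size([2, 4])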