
Commit

addressed jareds comments
shoeybi committed Jun 5, 2020
1 parent 7802200 commit 197c132
Showing 5 changed files with 23 additions and 18 deletions.

megatron/arguments.py (3 additions, 0 deletions)

@@ -97,6 +97,9 @@ def parse_args(extra_args_provider=None, defaults={},
     if args.num_unique_layers < args.num_layers:
         assert args.DDP_impl == 'local', \
             'torch-DDP does not work with parameters sharing.'
+    # Mixed precision checks.
+    if args.fp16_lm_cross_entropy:
+        assert args.fp16, 'lm cross entropy in fp16 only support in fp16 mode.'
 
     _print_args(args)
     return args
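
For context, the flags behind `args.fp16` and `args.fp16_lm_cross_entropy` are registered elsewhere in arguments.py and are not part of this diff. A minimal sketch of how such a pair of flags and the new check could fit together, with the flag spellings inferred from the attribute names in the diff rather than taken from the file, might look like this:

    import argparse

    # Sketch only: the real parse_args() in megatron/arguments.py defines many
    # more options. Flag names are inferred from args.fp16 and
    # args.fp16_lm_cross_entropy, which argparse derives from these spellings.
    parser = argparse.ArgumentParser(description='mixed precision flags (sketch)')
    parser.add_argument('--fp16', action='store_true',
                        help='Run the model in fp16 mode.')
    parser.add_argument('--fp16-lm-cross-entropy', action='store_true',
                        help='Keep the unreduced LM cross entropy loss in fp16.')
    args = parser.parse_args()

    # The check added in this commit: the fp16 loss path only makes sense
    # when the model itself runs in fp16.
    if args.fp16_lm_cross_entropy:
        assert args.fp16, 'lm cross entropy in fp16 only supported in fp16 mode.'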

megatron/model/bert_model.py (7 additions, 1 deletion)

@@ -115,6 +115,7 @@ def __init__(self, num_tokentypes=2, add_binary_head=True,
         super(BertModel, self).__init__()
         args = get_args()
 
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
         self.add_binary_head = add_binary_head
         self.parallel_output = parallel_output
         init_method = init_method_normal(args.init_method_std)
@@ -170,7 +171,12 @@ def forward(self, input_ids, attention_mask,
         if lm_labels is None:
             return lm_logits, binary_logits
         else:
-            lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            if self.fp16_lm_cross_entropy:
+                assert lm_logits.dtype == torch.half
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
+            else:
+                lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
+                                                           lm_labels)
             return lm_loss, binary_logits
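
The same dispatch is added to gpt2_model.py below: when `fp16_lm_cross_entropy` is set, the vocab-parallel cross entropy consumes the half-precision logits directly; otherwise the logits are upcast to fp32 first. A minimal single-GPU sketch of the pattern, using `torch.nn.functional.cross_entropy` with `reduction='none'` as a stand-in for `mpu.vocab_parallel_cross_entropy` (the real op works on vocab-parallel logits but likewise returns an unreduced per-token loss), might look like this:

    import torch
    import torch.nn.functional as F

    def lm_cross_entropy(lm_logits, lm_labels, fp16_lm_cross_entropy):
        # lm_logits: [batch, seq, vocab], fp16 when the model runs in fp16
        # (the half-precision path assumes CUDA tensors).
        # lm_labels: [batch, seq] int64 token ids.
        # F.cross_entropy expects the class dimension second, hence the
        # transpose; reduction='none' keeps the per-token loss that the
        # trainer later masks and averages.
        if fp16_lm_cross_entropy:
            # Compute the loss directly on the half-precision logits and
            # avoid materializing a full fp32 copy of the logits tensor.
            assert lm_logits.dtype == torch.half
            return F.cross_entropy(lm_logits.transpose(1, 2), lm_labels,
                                   reduction='none')
        # Default path: upcast to fp32 before the softmax / cross entropy.
        return F.cross_entropy(lm_logits.transpose(1, 2).float(), lm_labels,
                               reduction='none')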

megatron/model/gpt2_model.py (6 additions, 1 deletion)

@@ -40,6 +40,7 @@ def __init__(self, num_tokentypes=0, parallel_output=True):
         args = get_args()
 
         self.parallel_output = parallel_output
+        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
 
         self.language_model, self._language_model_key = get_language_model(
             attention_mask_func=gpt2_attention_mask_func,
@@ -79,7 +80,11 @@ def forward(self, input_ids, position_ids, attention_mask, labels=None,
         if labels is None:
             return output
         else:
-            loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            if self.fp16_lm_cross_entropy:
+                assert output.dtype == torch.half
+                loss = mpu.vocab_parallel_cross_entropy(output, labels)
+            else:
+                loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
             return loss
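
The point of the fp16 path is the size of the logits tensor: `output.float()` materializes a second, full-precision copy of a [batch, seq, vocab] activation before the loss is computed. An illustrative back-of-the-envelope calculation (the batch size, sequence length, and vocabulary size below are assumed values, not configuration from this commit):

    # Rough cost of upcasting GPT-2-style logits to fp32 before the loss.
    # All sizes are illustrative assumptions, not values from this commit.
    batch, seq, vocab = 8, 1024, 50304
    elements = batch * seq * vocab              # ~412M logit entries
    fp16_gib = elements * 2 / 1024**3           # ~0.77 GiB already held in fp16
    fp32_copy_gib = elements * 4 / 1024**3      # ~1.54 GiB extra for output.float()
    print(f'{elements / 1e6:.0f}M logits: fp16 {fp16_gib:.2f} GiB, '
          f'extra fp32 copy {fp32_copy_gib:.2f} GiB')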

pretrain_bert.py (5 additions, 9 deletions)

@@ -78,16 +78,12 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model. lm_labels
-    if args.fp16_lm_cross_entropy:
-        lm_loss_, sop_logits = model(tokens, padding_mask, tokentype_ids=types,
-                                     lm_labels=lm_labels)
-    else:
-        lm_logits, sop_logits = model(tokens, padding_mask, tokentype_ids=types)
-        lm_loss_ = mpu.vocab_parallel_cross_entropy(
-            lm_logits.contiguous().float(), lm_labels.contiguous())
+    lm_loss_, sop_logits = model(tokens, padding_mask,
+                                 tokentype_ids=types,
+                                 lm_labels=lm_labels)
 
-    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).contiguous().float(),
-                               sentence_order.view(-1).contiguous(),
+    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
+                               sentence_order.view(-1),
                                ignore_index=-1)
 
     lm_loss = torch.sum(
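
The sentence-order loss is unchanged apart from dropping the redundant `.contiguous()` calls; `ignore_index=-1` still makes padded entries of `sentence_order` contribute nothing. A small self-contained illustration of that behavior (the tensor values are made up):

    import torch
    import torch.nn.functional as F

    # Illustrative sentence-order-prediction logits for four segment pairs,
    # with -1 marking targets that should be ignored. Values are made up.
    sop_logits = torch.randn(4, 2)
    sentence_order = torch.tensor([0, 1, -1, 1])

    # Targets equal to ignore_index are skipped entirely; the mean is taken
    # over the remaining three pairs, mirroring the call in forward_step.
    sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
                               sentence_order.view(-1),
                               ignore_index=-1)
    print(sop_loss)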

pretrain_gpt2.py (2 additions, 7 deletions)

@@ -82,13 +82,8 @@ def forward_step(data_iterator, model):
     timers('batch generator').stop()
 
     # Forward model.
-    if args.fp16_lm_cross_entropy:
-        losses = model(tokens, position_ids, attention_mask, labels=labels)
-    else:
-        output = model(tokens, position_ids, attention_mask)
-        losses = mpu.vocab_parallel_cross_entropy(output.contiguous().float(),
-                                                  labels)
-
+    losses = model(tokens, position_ids, attention_mask, labels=labels)
+
     loss_mask = loss_mask.view(-1)
     loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
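
With the branch removed, `forward_step` always receives the unreduced per-token losses from the model and averages them over the unmasked positions only. A tiny sketch of that reduction with synthetic tensors (shapes and values are illustrative):

    import torch

    # Per-token LM losses as returned by the model when labels are passed in,
    # plus a mask that zeroes out padded positions. Values are made up.
    losses = torch.tensor([[2.3, 1.7, 0.9, 4.1],
                           [1.1, 3.2, 0.5, 2.8]])
    loss_mask = torch.tensor([[1., 1., 1., 0.],
                              [1., 1., 0., 0.]])

    # Same reduction as in forward_step: average only over unmasked tokens.
    loss_mask = loss_mask.view(-1)
    loss = torch.sum(losses.view(-1) * loss_mask) / loss_mask.sum()
    print(loss)  # mean of the five unmasked per-token losses: 1.84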
