Merge: [BERT/PyT] [ELECTRA/TF2] resume p2 option, fix early stopping
nv-kkudrynski committed Jun 13, 2022
2 parents 340db9e + 544a2d6 commit 691a34f
Showing 2 changed files with 8 additions and 4 deletions.
10 changes: 7 additions & 3 deletions PyTorch/LanguageModeling/BERT/run_pretraining.py
@@ -250,6 +250,10 @@ def parse_arguments():
                         default=False,
                         action='store_true',
                         help="Whether to train with seq len 512")
+    parser.add_argument('--resume_phase2',
+                        default=False,
+                        action='store_true',
+                        help="Whether to resume training with seq len 512")
     parser.add_argument('--allreduce_post_accumulation',
                         default=False,
                         action='store_true',
@@ -427,13 +431,13 @@ def prepare_model_and_optimizer(args, device, sequence_output_is_dense):
     model.checkpoint_activations(args.checkpoint_activations)
 
     if args.resume_from_checkpoint:
-        # For phase2, need to reset the learning rate and step count in the checkpoint
-        if args.phase2 or args.init_checkpoint :
+        # For phase2 from scratch, need to reset the learning rate and step count in the checkpoint. Else restore values in checkpoint.
+        if (args.phase2 and not args.resume_phase2) or args.init_checkpoint :
             for group in checkpoint['optimizer']['param_groups'] :
                 group['step'].zero_()
                 group['lr'].fill_(args.learning_rate)
         else :
-            if 'grad_scaler' in checkpoint and not args.phase2:
+            if 'grad_scaler' in checkpoint and (not args.phase2 or args.resume_phase2):
                 grad_scaler.load_state_dict(checkpoint['grad_scaler'])
         optimizer.load_state_dict(checkpoint['optimizer'])  # , strict=False)
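For reference, here is a minimal standalone sketch of the reset-versus-restore decision that the new --resume_phase2 flag selects. The helper maybe_reset_phase2_state and the checkpoint values are made up for illustration and are not part of the repository; only the branch condition mirrors the diff above.

# Hedged sketch, not repository code: shows when the optimizer's recorded
# step counter and learning rate in a checkpoint are reset for phase2.
import torch

def maybe_reset_phase2_state(checkpoint, phase2, resume_phase2, init_checkpoint, learning_rate):
    # Fresh phase2 (or an explicit init checkpoint): wipe the recorded step
    # count and override the learning rate so the phase2 schedule starts anew.
    if (phase2 and not resume_phase2) or init_checkpoint:
        for group in checkpoint['optimizer']['param_groups']:
            group['step'].zero_()
            group['lr'].fill_(learning_rate)
    return checkpoint

def toy_checkpoint():
    # Made-up values standing in for a saved optimizer state.
    return {'optimizer': {'param_groups': [{'step': torch.tensor(7038.0),
                                            'lr': torch.tensor(6e-3)}]}}

# Starting phase2 from a phase1 checkpoint: step -> 0, lr -> the new value.
fresh = maybe_reset_phase2_state(toy_checkpoint(), phase2=True, resume_phase2=False,
                                 init_checkpoint=None, learning_rate=4e-3)
print(fresh['optimizer']['param_groups'][0])

# Resuming an interrupted phase2 run (--resume_phase2): state is kept as saved,
# and (per the second change above) the grad scaler state is also restored.
kept = maybe_reset_phase2_state(toy_checkpoint(), phase2=True, resume_phase2=True,
                                init_checkpoint=None, learning_rate=4e-3)
print(kept['optimizer']['param_groups'][0])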

2 changes: 1 addition & 1 deletion TensorFlow2/LanguageModeling/ELECTRA/run_pretraining.py
@@ -477,7 +477,7 @@ def main(e2e_start_time):
                 iter_save_path = iter_manager.save(checkpoint_number=step)
                 log(" ** Saved iterator checkpoint for step {}: {}".format(step, iter_save_path), all_rank=True)
             local_step += 1
-            if (local_step % (config.steps_this_run * args.gradient_accumulation_steps) == 0):
+            if config.steps_this_run != -1 and (local_step % (config.steps_this_run * args.gradient_accumulation_steps) == 0):
                 #terminating run sooner as steps_this_run has been reached
                 log("terminating as steps_this_run:{} has been reached".format(config.steps_this_run))
                 break
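The added guard matters because steps_this_run appears to use -1 as its "no early-stop" sentinel, and Python's modulo with a negative divisor still yields 0 whenever local_step is a multiple of gradient_accumulation_steps, so the unguarded check could end the run after the first accumulation window. A standalone repro with made-up values (nothing below is repository code):

# Standalone repro of the early-stopping bug the guard above fixes.
gradient_accumulation_steps = 8   # made-up value
steps_this_run = -1               # sentinel: no early-stop requested

for local_step in range(1, 25):
    old_stop = (local_step % (steps_this_run * gradient_accumulation_steps) == 0)
    new_stop = steps_this_run != -1 and (
        local_step % (steps_this_run * gradient_accumulation_steps) == 0)
    if old_stop != new_stop:
        # The old check fires at local_step 8, 16, 24 (every multiple of 8);
        # the guarded check never fires while steps_this_run is -1.
        print(local_step, "old:", old_stop, "fixed:", new_stop)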
