Skip to content

Commit

Permalink
Add a termination condition
Browse files Browse the repository at this point in the history
  • Loading branch information
nikitakit committed Dec 20, 2018
1 parent 00526c6 commit f54cd3e
Showing 1 changed file with 9 additions and 3 deletions.
12 changes: 9 additions & 3 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def make_hparams():
step_decay=True, # note that disabling step decay is not implemented
step_decay_factor=0.5,
step_decay_patience=5,
max_consecutive_decays=3, # establishes a termination criterion

partitioned=True,
num_layers_position_only=0,
Expand Down Expand Up @@ -231,12 +232,14 @@ def schedule_lr(iteration):
check_every = len(train_parse) / args.checks_per_epoch
best_dev_fscore = -np.inf
best_dev_model_path = None
best_dev_processed = 0

start_time = time.time()

def check_dev():
nonlocal best_dev_fscore
nonlocal best_dev_model_path
nonlocal best_dev_processed

dev_start_time = time.time()

Expand Down Expand Up @@ -272,6 +275,7 @@ def check_dev():
best_dev_fscore = dev_fscore.fscore
best_dev_model_path = "{}_dev={:.2f}".format(
args.model_path_base, dev_fscore.fscore)
best_dev_processed = total_processed
print("Saving new best model to {}...".format(best_dev_model_path))
torch.save({
'spec': parser.spec,
Expand Down Expand Up @@ -334,9 +338,11 @@ def check_dev():
check_dev()

# adjust learning rate at the end of an epoch
if hparams.step_decay:
if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
scheduler.step(best_dev_fscore)
if (total_processed // args.batch_size + 1) > hparams.learning_rate_warmup_steps:
scheduler.step(best_dev_fscore)
if (total_processed - best_dev_processed) > ((hparams.step_decay_patience + 1) * hparams.max_consecutive_decays * len(train_parse)):
print("Terminating due to lack of improvement in dev fscore.")
break

def run_test(args):
print("Loading test trees from {}...".format(args.test_path))
Expand Down

0 comments on commit f54cd3e

Please sign in to comment.