Commit

add lr layer decay
kimiyoung committed Jun 21, 2019
1 parent 0b642d1 commit 23728ae
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions model_utils.py
@@ -142,6 +142,23 @@ def get_train_op(FLAGS, total_loss, grads_and_vars=None):
     grads_and_vars = optimizer.compute_gradients(total_loss)
   gradients, variables = zip(*grads_and_vars)
   clipped, gnorm = tf.clip_by_global_norm(gradients, FLAGS.clip)
+
+  if FLAGS.lr_layer_decay_rate != 1.0:
+    n_layer = 0
+    for i in range(len(clipped)):
+      m = re.search(r"model/transformer/layer_(\d+?)/", variables[i].name)
+      if not m: continue
+      n_layer = max(n_layer, int(m.group(1)) + 1)
+
+    for i in range(len(clipped)):
+      for l in range(n_layer):
+        if "model/transformer/layer_{}/".format(l) in variables[i].name:
+          abs_rate = FLAGS.lr_layer_decay_rate ** (n_layer - 1 - l)
+          clipped[i] *= abs_rate
+          tf.logging.info("Apply mult {:.4f} to layer-{} grad of {}".format(
+              abs_rate, l, variables[i].name))
+          break
 
   train_op = optimizer.apply_gradients(
       zip(clipped, variables), global_step=global_step)

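For reference, a minimal standalone sketch of the scaling rule the commit adds: each transformer layer's clipped gradient is multiplied by lr_layer_decay_rate ** (n_layer - 1 - l), so the top layer keeps its full gradient and lower layers are shrunk geometrically. The layer-name pattern and the exponent come from the diff above; the dictionary of variable names, the gradient values, and the rate 0.75 below are hypothetical stand-ins for illustration only.

import re

# Hypothetical (variable name -> clipped gradient) pairs; the real code operates on tensors.
grads = {
    "model/transformer/layer_0/ff/kernel": 1.0,
    "model/transformer/layer_1/ff/kernel": 1.0,
    "model/transformer/layer_2/ff/kernel": 1.0,
    "model/lm_loss/bias": 1.0,  # does not match the layer pattern, so it stays unscaled
}
lr_layer_decay_rate = 0.75  # stand-in for FLAGS.lr_layer_decay_rate

# Infer the number of layers from the highest layer index seen, as the commit does.
n_layer = 0
for name in grads:
    m = re.search(r"model/transformer/layer_(\d+?)/", name)
    if m:
        n_layer = max(n_layer, int(m.group(1)) + 1)

# Scale each layer's gradient by decay_rate ** (n_layer - 1 - l).
for name in grads:
    for l in range(n_layer):
        if "model/transformer/layer_{}/".format(l) in name:
            grads[name] *= lr_layer_decay_rate ** (n_layer - 1 - l)
            break

print(grads)
# layer_0 -> 0.5625, layer_1 -> 0.75, layer_2 -> 1.0, lm_loss/bias -> 1.0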
