Skip to content

Commit

Permalink
Internal change
Browse files — browse the repository at this point in the history
PiperOrigin-RevId: 264474346
  • Loading branch information
tensorflower-gardener committed Aug 20, 2019
1 parent 8089a56 commit dab0c03
Showing 1 changed file with 8 additions and 0 deletions.
8 changes: 8 additions & 0 deletions official/bert/optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,14 @@ def __init__(self,
epsilon=1e-7,
amsgrad=False,
weight_decay_rate=0.0,
include_in_weight_decay=None,
exclude_from_weight_decay=None,
name='AdamWeightDecay',
**kwargs):
super(AdamWeightDecay, self).__init__(
learning_rate, beta_1, beta_2, epsilon, amsgrad, name, **kwargs)
self.weight_decay_rate = weight_decay_rate
self._include_in_weight_decay = include_in_weight_decay
self._exclude_from_weight_decay = exclude_from_weight_decay

@classmethod
Expand Down Expand Up @@ -178,6 +180,12 @@ def _do_use_weight_decay(self, param_name):
"""Whether to use L2 weight decay for `param_name`."""
if self.weight_decay_rate == 0:
return False

if self._include_in_weight_decay:
for r in self._include_in_weight_decay:
if re.search(r, param_name) is not None:
return True

if self._exclude_from_weight_decay:
for r in self._exclude_from_weight_decay:
if re.search(r, param_name) is not None:
Expand Down

0 comments on commit dab0c03

Please sign in to comment.