Skip to content

Commit

Permalink
BatchNorm update bug fix. Replaced mutable=True in model.apply()
Browse files Browse the repository at this point in the history
…during training with `mutable=['batch_stats', 'get_bounds']`, otherwise BN statistics would not get updated during training due to recent change in flax BN implementation.

PiperOrigin-RevId: 374920534
  • Loading branch information
lisa-1010 authored and copybara-github committed May 20, 2021
1 parent e4252b1 commit ae9d07f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions aqt/jax/imagenet/train_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,17 @@ def train_step(model, state, batch, hparams, update_bounds, learning_rate_fn):

def loss_fn(params):
"""loss function used for training."""
variables = {'params': params, **state.model_state}
variables = {'params': params}
variables.update(state.model_state)
logits, new_model_state = model.apply(
variables, batch['image'], mutable=True)
variables, batch['image'], mutable=['batch_stats', 'get_bounds'])
loss = cross_entropy_loss(logits, batch['label'])
weight_penalty_params = jax.tree_leaves(variables['params'])
weight_decay = hparams.weight_decay
weight_l2 = sum(
[jnp.sum(x**2) for x in weight_penalty_params if x.ndim > 1])
weight_penalty = weight_decay * 0.5 * weight_l2
loss = loss + weight_penalty
new_model_state, _ = new_model_state.pop('params')
return loss, (new_model_state, logits)

step = state.step
Expand Down

0 comments on commit ae9d07f

Please sign in to comment.