Skip to content

Commit

Permalink
add optimizer utilities (fixes jax-ml#244 and jax-ml#681)
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed May 6, 2019
1 parent b226cbc commit e751189
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 1 deletion.
17 changes: 16 additions & 1 deletion jax/experimental/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@
import jax.numpy as np
from jax.util import partial, safe_zip, safe_map, unzip2
from jax import tree_util
from jax.tree_util import tree_flatten, tree_unflatten, register_pytree_node
from jax.tree_util import (tree_map, tree_flatten, tree_unflatten,
register_pytree_node)

map = safe_map
zip = safe_zip
Expand Down Expand Up @@ -376,3 +377,17 @@ def make_schedule(scalar_or_schedule):
return constant(scalar_or_schedule)
else:
raise TypeError(type(scalar_or_schedule))


### utilities

def l2_norm(tree):
"""Compute the l2 norm of a pytree of arrays. Useful for weight decay."""
leaves, _ = tree_flatten(tree)
return np.sqrt(sum(np.vdot(x, x) for x in leaves))

def clip_grads(grad_tree, max_norm):
"""Clip gradients stored as a pytree of arrays to maximum norm `max_norm`."""
norm = l2_norm(grad_tree)
normalize = lambda g: np.where(norm < max_norm, g, g * (max_norm / norm))
return tree_map(normalize, grad_tree)
22 changes: 22 additions & 0 deletions tests/optimizers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
import functools

from absl.testing import absltest
import numpy as onp

import jax.numpy as np
import jax.test_util as jtu
from jax import jit, grad
Expand Down Expand Up @@ -213,6 +215,26 @@ def get_params(opt_state):
opt_state = init_fun(np.zeros(3))
self.assertRaises(TypeError, lambda: update_fun(opt_state))

def testUtilityNorm(self):
x0 = (np.ones(2), (np.ones(3), np.ones(4)))
norm = optimizers.l2_norm(x0)
expected = onp.sqrt(onp.sum(onp.ones(2+3+4)**2))
self.assertAllClose(norm, expected, check_dtypes=False)

def testUtilityClipGrads(self):
g = (np.ones(2), (np.ones(3), np.ones(4)))
norm = optimizers.l2_norm(g)

ans = optimizers.clip_grads(g, 1.1 * norm)
expected = g
self.assertAllClose(ans, expected, check_dtypes=False)

ans = optimizers.l2_norm(optimizers.clip_grads(g, 0.9 * norm))
expected = 0.9 * norm
self.assertAllClose(ans, expected, check_dtypes=False)




if __name__ == '__main__':
absltest.main()

0 comments on commit e751189

Please sign in to comment.