Skip to content

Commit

Permalink
Fix JVP rule for lax.pow()
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Aug 23, 2022
1 parent a73a6a8 commit b8fe0ab
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 2 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK.

## jax 0.3.17 (Unreleased)
* [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main).
* Bugs
* Fix corner case issue in gradient of `lax.pow` with an exponent of zero
({jax-issue}`12041`)
* Breaking changes
* {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports
the `concrete` option, following the previous version's deprecation; see
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/lax/lax.py
Original file line number Diff line number Diff line change
Expand Up @@ -2004,8 +2004,7 @@ def _abs_jvp_rule(g, ans, x):
pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow')

def _pow_jvp_lhs(g, ans, x, y):
jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y)))))
return mul(g, jac)
return mul(g, mul(y, pow(x, sub(y, _ones(y)))))

def _pow_jvp_rhs(g, ans, x, y):
return mul(g, mul(log(_replace_zero(x)), ans))
Expand Down
20 changes: 20 additions & 0 deletions tests/lax_autodiff_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,26 @@ def testReverseGrad(self):
check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2,
rtol={np.float32: 3e-3})

def testPowSecondDerivative(self):
# https://github.com/google/jax/issues/12033
x, y = 4.0, 0.0
expected = ((0.0, 1/x), (1/x, np.log(x) ** 2))

with self.subTest("jacfwd"):
result_fwd = jax.jacfwd(jax.jacfwd(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_fwd, expected)

with self.subTest("jacrev"):
result_rev = jax.jacrev(jax.jacrev(lax.pow, (0, 1)), (0, 1))(x, y)
self.assertAllClose(result_rev, expected)

with self.subTest("zero to the zero"):
result = jax.grad(lax.pow)(0.0, 0.0)
# TODO(jakevdp) special-case zero in a way that doesn't break other cases
# See https://github.com/google/jax/pull/12041#issuecomment-1222766191
# self.assertEqual(result, 0.0)
self.assertAllClose(result, np.nan)

@parameterized.named_parameters(jtu.cases_from_list(
{"testcase_name": "_predshape={}_argshapes={}".format(
jtu.format_shape_dtype_string(pred_shape, np.bool_),
Expand Down

0 comments on commit b8fe0ab

Please sign in to comment.