From b8fe0ab8b1c34ba18d355aca973ccfbfefad8acd Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Mon, 22 Aug 2022 11:53:02 -0700 Subject: [PATCH] Fix JVP rule for lax.pow() --- CHANGELOG.md | 3 +++ jax/_src/lax/lax.py | 3 +-- tests/lax_autodiff_test.py | 20 ++++++++++++++++++++ 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1c341d8756ee..c417f5c17f6b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -10,6 +10,9 @@ PLEASE REMEMBER TO CHANGE THE '..main' WITH AN ACTUAL TAG in GITHUB LINK. ## jax 0.3.17 (Unreleased) * [GitHub commits](https://github.com/google/jax/compare/jax-v0.3.16...main). +* Bugs + * Fix corner case issue in gradient of `lax.pow` with an exponent of zero + ({jax-issue}`12041`) * Breaking changes * {func}`jax.checkpoint`, also known as {func}`jax.remat`, no longer supports the `concrete` option, following the previous version's deprecation; see diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6d9d1af78d08..eb618a59fc3c 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -2004,8 +2004,7 @@ def _abs_jvp_rule(g, ans, x): pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow') def _pow_jvp_lhs(g, ans, x, y): - jac = mul(y, pow(x, select(eq(y, _zeros(y)), _ones(y), sub(y, _ones(y))))) - return mul(g, jac) + return mul(g, mul(y, pow(x, sub(y, _ones(y))))) def _pow_jvp_rhs(g, ans, x, y): return mul(g, mul(log(_replace_zero(x)), ans)) diff --git a/tests/lax_autodiff_test.py b/tests/lax_autodiff_test.py index e9c7bb6b6e26..f62e6ae6f8b2 100644 --- a/tests/lax_autodiff_test.py +++ b/tests/lax_autodiff_test.py @@ -529,6 +529,26 @@ def testReverseGrad(self): check_grads(rev, (np.array([[6., 5., 4.], [3., 2., 1.]]),), 2, rtol={np.float32: 3e-3}) + def testPowSecondDerivative(self): + # https://github.com/google/jax/issues/12033 + x, y = 4.0, 0.0 + expected = ((0.0, 1/x), (1/x, np.log(x) ** 2)) + + with self.subTest("jacfwd"): + result_fwd = jax.jacfwd(jax.jacfwd(lax.pow, (0, 1)), (0, 1))(x, y) + self.assertAllClose(result_fwd, expected) + + with self.subTest("jacrev"): + result_rev = jax.jacrev(jax.jacrev(lax.pow, (0, 1)), (0, 1))(x, y) + self.assertAllClose(result_rev, expected) + + with self.subTest("zero to the zero"): + result = jax.grad(lax.pow)(0.0, 0.0) + # TODO(jakevdp) special-case zero in a way that doesn't break other cases + # See https://github.com/google/jax/pull/12041#issuecomment-1222766191 + # self.assertEqual(result, 0.0) + self.assertAllClose(result, np.nan) + @parameterized.named_parameters(jtu.cases_from_list( {"testcase_name": "_predshape={}_argshapes={}".format( jtu.format_shape_dtype_string(pred_shape, np.bool_),