Skip to content

Commit

Permalink
Raise error when differentiating w.r.t. outer variable with defjvp_all
Browse files Browse the repository at this point in the history
  • Loading branch information
j-towns committed Jun 27, 2019
1 parent 31fa041 commit 323d9f5
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 1 deletion.
8 changes: 7 additions & 1 deletion jax/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1118,7 +1118,13 @@ def defjvp_all(fun, custom_jvp):
_check_custom_transforms_type("defjvp_all", fun)
def custom_transforms_jvp(primals, tangents, **params):
consts, jax_kwargs, jax_args = primals[0], primals[1], primals[2:]
_, _, jax_args_dot = tangents[0], tangents[1], tangents[2:]
consts_dot, _, jax_args_dot = tangents[0], tangents[1], tangents[2:]
if consts_dot is not ad_util.zero:
msg = (
"Detected differentiation w.r.t. variables from outside the scope of "
"{}, but defjvp and defjvp_all only support differentiation w.r.t. "
"positional arguments.")
raise ValueError(msg.format(str(fun)))
if jax_kwargs:
msg = ("defjvp_all requires the corresponding custom_transforms function "
"not to be called with keyword arguments.")
Expand Down
14 changes: 14 additions & 0 deletions tests/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,6 +562,20 @@ def foo(x, y):
self.assertAllClose(grad_ans, 3. * 4. + onp.cos(onp.sin(3. * 4)),
check_dtypes=False)

def test_defjvp_closure_error(self):
def foo(x):
@api.custom_transforms
def bar(y):
return x * y

api.defjvp(bar, lambda y_dot, ans, y: x * y)
return bar(x)
jtu.check_raises(
lambda: api.jvp(foo, (1.,), (1.,)), ValueError,
"Detected differentiation w.r.t. variables from outside "
"the scope of <jax.custom_transforms function bar>, but defjvp and "
"defjvp_all only support differentiation w.r.t. positional arguments.")

def test_custom_transforms_eval_with_pytrees(self):
@api.custom_transforms
def f(x):
Expand Down

0 comments on commit 323d9f5

Please sign in to comment.