Skip to content

Commit

Permalink
[fix] jacrev and jacfwd : support non-tensor args again (pytorch#97746)
Browse files Browse the repository at this point in the history
Fixes pytorch#97636

The code to check if argument tensor are complex assumed that all arguments are tensor (which is not the case) which lead to the error.

Pull Request resolved: pytorch#97746
Approved by: https://github.com/zou3519
  • Loading branch information
kshitij12345 authored and pytorchmergebot committed Mar 28, 2023
1 parent 1c83888 commit 2b369eb
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/functorch/test_eager_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2216,6 +2216,16 @@ def fn(x):
with self.assertRaisesRegex(RuntimeError, "jacfwd: Expected all outputs"):
jacfwd(fn)(x)

@jacrev_and_jacfwd
def test_jac_with_non_tensor_args(self, device, jacapi):
def f(t, int_x):
return t + int_x

t = torch.randn(3, 3, device=device)

actual = jacapi(f)(t, 3)
expected = torch.autograd.functional.jacobian(partial(f, int_x=3), t)
self.assertEqual(actual, expected)

class TestHessian(TestCase):
def _test_against_reference(self, f, inputs):
Expand Down
2 changes: 1 addition & 1 deletion torch/_functorch/eager_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,7 @@ def _safe_zero_index(x):
def error_if_complex(func_name, args, is_input):
flat_args, _ = tree_flatten(args)
for idx, arg in enumerate(flat_args):
if arg.dtype.is_complex:
if isinstance(arg, torch.Tensor) and arg.dtype.is_complex:
input_or_output = ("inputs" if is_input else "outputs")
err_msg = (f"{func_name}: Expected all {input_or_output} "
f"to be real but received complex tensor at flattened input idx: {idx}")
Expand Down

0 comments on commit 2b369eb

Please sign in to comment.