diff --git a/test/test_autograd.py b/test/test_autograd.py
index 671ebc449de4d..10e8ad4096921 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -7021,6 +7021,32 @@ def run_tests(fn):
         run_tests(lambda v: v.swapdims_(0, 0))
         run_tests(lambda v: v.swapaxes_(0, 0))
 
+    def test_autograd_print_tensor(self):
+        a = torch.ones(1, requires_grad=True)
+        a_clone = a.clone()
+        self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
+        self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")
+
+        with torch.no_grad():
+            b = a[:]
+            b *= 2
+
+        # Special handling for printing a view created in no-grad and modified
+        # in-place in no-grad.
+        self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")
+
+        class Func(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, x):
+                return x
+
+        c = Func.apply(a)
+        self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")
+
     def test_autograd_inplace_view_of_view(self):
         x = torch.zeros(2)
         with torch.no_grad():
diff --git a/torch/_tensor_str.py b/torch/_tensor_str.py
index 364149e7bb7e5..1293a0fd61aec 100644
--- a/torch/_tensor_str.py
+++ b/torch/_tensor_str.py
@@ -601,11 +601,22 @@ def indented_str(s, indent):
 
     # Use inp here to get the original grad_fn and not the one generated by the forward grad
     # unpacking.
-    if inp.grad_fn is not None:
-        name = type(inp.grad_fn).__name__
-        if name == "CppFunction":
-            name = inp.grad_fn.name().rsplit("::", 1)[-1]
-        suffixes.append(f"grad_fn=<{name}>")
+    grad_fn_name = None
+    try:
+        grad_fn = inp.grad_fn
+    except RuntimeError:
+        # Accessing the grad_fn calls rebasing logic which would cause an error
+        # if that tensor is a view created in no-grad mode and modified in-place
+        # in no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
+        grad_fn_name = "Invalid"
+
+    if grad_fn_name is None and grad_fn is not None:
+        grad_fn_name = type(grad_fn).__name__
+        if grad_fn_name == "CppFunction":
+            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
+
+    if grad_fn_name is not None:
+        suffixes.append(f"grad_fn=<{grad_fn_name}>")
     elif inp.requires_grad:
         suffixes.append("requires_grad=True")
 
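
For reference, a minimal repro sketch of the failure this patch guards against, reconstructed from the new test and the linked issue (https://github.com/pytorch/pytorch/issues/99968); the exact `grad_fn` placeholder shown assumes the behavior added by this patch:

```python
import torch

a = torch.ones(1, requires_grad=True)
with torch.no_grad():
    b = a[:]  # view of a leaf tensor, created while grad mode is disabled
    b *= 2    # in-place update of that view, still inside no_grad

# Accessing b.grad_fn runs the view-rebase logic, which raises a
# RuntimeError for this kind of view, so before this patch repr(b)
# (and plain print(b)) failed outright. With the patch, _tensor_str
# catches the RuntimeError and prints a placeholder name instead:
print(repr(b))  # tensor([2.], grad_fn=<Invalid>)
```

Catching only `RuntimeError` keeps the fallback narrow: any other exception from the `grad_fn` access still propagates, and ordinary tensors keep going through the unchanged `type(grad_fn).__name__` / `CppFunction` path.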