Do not error when printing view created in no-grad modified in-place in no-grad (pytorch#113716)

Fixes pytorch#99968

Pull Request resolved: pytorch#113716
Approved by: https://github.com/albanD
soulitzer authored and pytorchmergebot committed Nov 16, 2023
1 parent 6cdb623 commit 3e3c6cc
Showing 2 changed files with 42 additions and 5 deletions.
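
For context, a minimal reproduction sketch of the failure this commit addresses, pieced together from the linked issue and the new test below (not part of the commit itself; the printed output reflects the post-fix behavior asserted in test_autograd_print_tensor):

import torch

a = torch.ones(1, requires_grad=True)
with torch.no_grad():
    b = a[:]   # view created while grad mode is disabled
    b *= 2     # modified in-place, still under no-grad

# Before this change, repr(b) raised a RuntimeError because formatting the
# tensor accessed b.grad_fn, which errors for such views (pytorch#99968).
# After this change, printing falls back to a placeholder grad_fn name:
print(b)  # tensor([2.], grad_fn=<Invalid>)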
26 changes: 26 additions & 0 deletions test/test_autograd.py
@@ -7021,6 +7021,32 @@ def run_tests(fn):
        run_tests(lambda v: v.swapdims_(0, 0))
        run_tests(lambda v: v.swapaxes_(0, 0))

+    def test_autograd_print_tensor(self):
+        a = torch.ones(1, requires_grad=True)
+        a_clone = a.clone()
+        self.assertEqual(repr(a), "tensor([1.], requires_grad=True)")
+        self.assertEqual(repr(a_clone), "tensor([1.], grad_fn=<CloneBackward0>)")
+
+        with torch.no_grad():
+            b = a[:]
+            b *= 2
+
+        # Special handling for printing view created in no-grad and modified
+        # in-place in no-grad.
+        self.assertEqual(repr(b), "tensor([2.], grad_fn=<Invalid>)")
+
+        class Func(torch.autograd.Function):
+            @staticmethod
+            def forward(ctx, x):
+                return x
+
+            @staticmethod
+            def backward(ctx, x):
+                return x
+
+        c = Func.apply(a)
+        self.assertEqual(repr(c), "tensor([2.], grad_fn=<FuncBackward>)")
+
    def test_autograd_inplace_view_of_view(self):
        x = torch.zeros(2)
        with torch.no_grad():
21 changes: 16 additions & 5 deletions torch/_tensor_str.py
@@ -601,11 +601,22 @@ def indented_str(s, indent):

    # Use inp here to get the original grad_fn and not the one generated by the forward grad
    # unpacking.
-    if inp.grad_fn is not None:
-        name = type(inp.grad_fn).__name__
-        if name == "CppFunction":
-            name = inp.grad_fn.name().rsplit("::", 1)[-1]
-        suffixes.append(f"grad_fn=<{name}>")
+    grad_fn_name = None
+    try:
+        grad_fn = inp.grad_fn
+    except RuntimeError:
+        # Accessing the grad_fn calls rebasing logic which would cause an error
+        # if that tensor is a view created in no-grad mode modified in-place in
+        # no-grad mode. See: https://github.com/pytorch/pytorch/issues/99968
+        grad_fn_name = "Invalid"
+
+    if grad_fn_name is None and grad_fn is not None:
+        grad_fn_name = type(grad_fn).__name__
+        if grad_fn_name == "CppFunction":
+            grad_fn_name = grad_fn.name().rsplit("::", 1)[-1]
+
+    if grad_fn_name is not None:
+        suffixes.append(f"grad_fn=<{grad_fn_name}>")
    elif inp.requires_grad:
        suffixes.append("requires_grad=True")
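
The change above only affects printing; accessing .grad_fn directly on such a view still raises, which is what the new try/except catches. A small sketch of that underlying failure (the exact error message is version-dependent and not asserted here):

import torch

a = torch.ones(1, requires_grad=True)
with torch.no_grad():
    b = a[:]
    b *= 2

try:
    _ = b.grad_fn  # rebasing the view's grad_fn raises for this tensor
except RuntimeError as err:
    # _str_intern now takes this branch and reports grad_fn=<Invalid>
    print(f"grad_fn access failed: {err}")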
