Skip to content

Commit

Permalink
Fix FakeTensor printing (pytorch#99205)
Browse files Browse the repository at this point in the history
I got too confused by the FakeTensor printing, so this PR fixes it to
print normally.

Before:
```
with FakeTensorMode():
    x = torch.empty(2, 2, device="cpu")
    print(x)
    # FakeTensor(FakeTensor(..., device='meta', size=(2, 2)), cpu)
```
After (Tensor printing doesn't print the default device):
```
FakeTensor(..., size=(2, 2))
```

Test Plan:
- new test
Pull Request resolved: pytorch#99205
Approved by: https://github.com/eellison
  • Loading branch information
zou3519 authored and pytorchmergebot committed Apr 18, 2023
1 parent 20a90a1 commit 57e1a50
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
7 changes: 7 additions & 0 deletions test/test_fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,13 @@ def test_shape_take_not_device(self):
self.assertEqual(out.device.type, "cpu")
self.assertTrue(isinstance(out, FakeTensor))

def test_repr(self):
with FakeTensorMode():
x = torch.empty(2, 2, device="cpu")
self.assertEqual(repr(x), 'FakeTensor(..., size=(2, 2))')
x = torch.empty(2, 2, device="meta")
self.assertEqual(repr(x), "FakeTensor(..., device='meta', size=(2, 2))")

@unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_zero_dim(self):
with FakeTensorMode() as mode:
Expand Down
6 changes: 0 additions & 6 deletions torch/_subclasses/fake_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -924,12 +924,6 @@ def __init__(self, *args, **kwargs):
def from_tensor(t, fake_mode):
return fake_mode.from_tensor(t)

# TODO: resolve error in default __repr__
def __repr__(self):
    # Render the underlying meta tensor's repr while posing as an
    # in-kernel invocation, then append the device this FakeTensor
    # advertises to callers.
    with in_kernel_invocation_manager(self.fake_mode):
        inner = super().__repr__()
    return "FakeTensor({}, {})".format(inner, self.fake_device)

@classmethod
@count
def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
Expand Down
7 changes: 5 additions & 2 deletions torch/_tensor_str.py
Original file line number Diff line number Diff line change
Expand Up @@ -536,12 +536,15 @@ def indented_str(s, indent):
prefix = "_to_functional_tensor("
tensor_str = repr(torch._from_functional_tensor(self))
else:
if self.is_meta:
# Circular import problem, so we import it here
from torch._subclasses.fake_tensor import FakeTensor

if self.is_meta or isinstance(self, FakeTensor):
suffixes.append("size=" + str(tuple(self.shape)))
if self.dtype != torch.get_default_dtype():
suffixes.append("dtype=" + str(self.dtype))
# TODO: This implies that ellipses is valid syntax for allocating
# a meta tensor, which it could be, but it isn't right now
# a meta tensor or FakeTensor, which it could be, but it isn't right now
if not custom_contents_provided:
tensor_str = "..."
else:
Expand Down

0 comments on commit 57e1a50

Please sign in to comment.