[fake_tensor] Move unrecognized_type NotImplemented before ConstProp (pytorch#135033)

We should not attempt ConstProp on unrecognized types (e.g. subclasses).
For those types, returning NotImplemented passes control to the next __torch_dispatch__.
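
A minimal sketch of the mechanism described above (`SkippingMode` is hypothetical, not part of this PR): a `__torch_dispatch__` that returns `NotImplemented` hands the op to the next handler in line, such as the subclass's own `__torch_dispatch__`.

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class SkippingMode(TorchDispatchMode):
    # Hypothetical mode illustrating the pattern FakeTensorMode follows here.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if any(t is not torch.Tensor for t in types):
            # A tensor subclass we don't recognize is involved: bail out so
            # the subclass's own __torch_dispatch__ gets a chance to run.
            return NotImplemented
        return func(*args, **kwargs)
```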

Test:
```
 python test/functorch/test_aotdispatch.py -k test_aot_test_subclasses_with_tensor_factories
```
Pull Request resolved: pytorch#135033
Approved by: https://github.com/zou3519, https://github.com/bdhirsh
IvanKobzarev authored and pytorchmergebot committed Sep 5, 2024
1 parent a096f28 commit 1efd341
Showing 1 changed file with 18 additions and 17 deletions.
torch/_subclasses/fake_tensor.py
```diff
@@ -1715,11 +1715,28 @@ def _dispatch_impl(
     ) -> Optional[FakeTensor]:
         flat_args, args_spec = pytree.tree_flatten((args, kwargs))
 
+        # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING
+        # We must throw NotImplemented in case of unrecognized types to handle subclasses.
+        # Throwing the exception will pass the control to the next __torch_dispatch__.
+        # See [subclass inputs] below
+        # NB: If you're seeing a mysterious infinite loop involving fake
+        # tensor, it might be related to this line. Though I'm not sure
+        # how you'll know to read this comment, as this line won't show up
+        # in the stack trace.
+        has_unrecognized_types = _check_for_subclass(flat_args)
+        if has_unrecognized_types:
+            unrecognized_types = [
+                type(x) for x in flat_args if _check_for_subclass_arg(x)
+            ]
+            not_implemented_log.debug(
+                "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types
+            )
+            return NotImplemented
+
         flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)]
         has_symbolic_sizes = any(
             i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors
         ) or any(isinstance(a, SymInt) for a in flat_args)
-        has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args)
 
         converter = self.fake_tensor_converter
 
@@ -1737,7 +1754,6 @@ def _dispatch_impl(
             should_allow_numbers_as_tensors(func)
             and not has_symbolic_sizes
             and not flat_arg_fake_tensors
-            and not has_subclasses
         ):
             assert all(
                 t.constant is not None for t in flat_arg_fake_tensors
@@ -1757,21 +1773,6 @@ def _dispatch_impl(
                 out = out.clone()
             return converter.from_real_tensor(self, out, make_constant=True)
 
-        # See [subclass inputs] below
-        # NB: If you're seeing a mysterious infinite loop involving fake
-        # tensor, it might be related to this line. Though I'm not sure
-        # how you'll know to read this comment, as this line won't show up
-        # in the stack trace.
-        has_unrecognized_types = _check_for_subclass(flat_args)
-        if has_unrecognized_types:
-            unrecognized_types = [
-                type(x) for x in flat_args if _check_for_subclass_arg(x)
-            ]
-            not_implemented_log.debug(
-                "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types
-            )
-            return NotImplemented
-
         # if we are in the dispatch mode, we will enter this function even if the inputs
         # are not FakeTensors. For now, throw if any non-Fake Tensor inputs
         # and just support constructors.
```
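As a simplified sketch of the resulting control flow (`_dispatch_impl_sketch` and `_const_prop_or_fake` are hypothetical stand-ins, not real helpers; `_check_for_subclass` does exist in fake_tensor.py):

```python
# Rough shape of FakeTensorMode._dispatch_impl after this change (a sketch,
# not the actual implementation).
def _dispatch_impl_sketch(self, func, flat_args):
    # 1. Bail out first on unrecognized types such as tensor subclasses,
    #    so that control reaches the subclass's __torch_dispatch__.
    if _check_for_subclass(flat_args):
        return NotImplemented
    # 2. Only afterwards attempt constant propagation and fake-tensor logic;
    #    before this PR, ConstProp ran first and could mishandle subclass args.
    return self._const_prop_or_fake(func, flat_args)  # hypothetical stand-in
```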
