From 1efd341d15cf9052f074734e75b913ff6db0cc39 Mon Sep 17 00:00:00 2001 From: IvanKobzarev Date: Thu, 5 Sep 2024 06:45:54 -0700 Subject: [PATCH] [fake_tensor] Move unrecognized_type NotImplemented before ConstProp (#135033) We should not try to do ConstProp on unrecognized types (e.g. subclasses). For those types, returning NotImplemented passes control to the next __torch_dispatch__. Test: ``` python test/functorch/test_aotdispatch.py -k test_aot_test_subclasses_with_tensor_factories ``` Pull Request resolved: https://github.com/pytorch/pytorch/pull/135033 Approved by: https://github.com/zou3519, https://github.com/bdhirsh --- torch/_subclasses/fake_tensor.py | 35 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 17 deletions(-) diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py index 51f726f53bacc3..27348dca80e483 100644 --- a/torch/_subclasses/fake_tensor.py +++ b/torch/_subclasses/fake_tensor.py @@ -1715,11 +1715,28 @@ def _dispatch_impl( ) -> Optional[FakeTensor]: flat_args, args_spec = pytree.tree_flatten((args, kwargs)) + # DO NOT PUT LOGIC BEFORE UNRECOGNIZED TYPE CHECKING + # We must throw NotImplemented in case of unrecognized types to handle subclasses. + # Throwing the exception will pass the control to the next __torch_dispatch__. + # See [subclass inputs] below + # NB: If you're seeing a mysterious infinite loop involving fake + # tensor, it might be related to this line. Though I'm not sure + # how you'll know to read this comment, as this line won't show up + # in the stack trace. 
+ has_unrecognized_types = _check_for_subclass(flat_args) + if has_unrecognized_types: + unrecognized_types = [ + type(x) for x in flat_args if _check_for_subclass_arg(x) + ] + not_implemented_log.debug( + "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types + ) + return NotImplemented + flat_arg_fake_tensors = [t for t in flat_args if self.is_our_fake(t)] has_symbolic_sizes = any( i._has_symbolic_sizes_strides for i in flat_arg_fake_tensors ) or any(isinstance(a, SymInt) for a in flat_args) - has_subclasses = any(is_traceable_wrapper_subclass(a) for a in flat_args) converter = self.fake_tensor_converter @@ -1737,7 +1754,6 @@ def _dispatch_impl( should_allow_numbers_as_tensors(func) and not has_symbolic_sizes and not flat_arg_fake_tensors - and not has_subclasses ): assert all( t.constant is not None for t in flat_arg_fake_tensors @@ -1757,21 +1773,6 @@ def _dispatch_impl( out = out.clone() return converter.from_real_tensor(self, out, make_constant=True) - # See [subclass inputs] below - # NB: If you're seeing a mysterious infinite loop involving fake - # tensor, it might be related to this line. Though I'm not sure - # how you'll know to read this comment, as this line won't show up - # in the stack trace. - has_unrecognized_types = _check_for_subclass(flat_args) - if has_unrecognized_types: - unrecognized_types = [ - type(x) for x in flat_args if _check_for_subclass_arg(x) - ] - not_implemented_log.debug( - "FakeTensorMode unrecognized subclass(es): %s", unrecognized_types - ) - return NotImplemented - # if we are in the dispatch mode, we will enter this function even if the inputs # are not FakeTensors. For now, throw if any non-Fake Tensor inputs # and just support constructors.