From a8431785292f3bebb73650261a8504a187a59630 Mon Sep 17 00:00:00 2001 From: Guilherme Leobas Date: Mon, 29 Jul 2024 14:21:39 -0300 Subject: [PATCH] Let dynamo inline functional_call (#128646) Pull Request resolved: https://github.com/pytorch/pytorch/pull/128646 Approved by: https://github.com/zou3519 --- test/dynamo/test_higher_order_ops.py | 168 ++++++++++++++++++++ torch/_dynamo/trace_rules.py | 10 +- torch/_dynamo/variables/__init__.py | 1 + torch/_dynamo/variables/higher_order_ops.py | 14 ++ torch/_functorch/functional_call.py | 7 +- torch/nn/utils/_named_member_accessor.py | 12 +- torch/nn/utils/stateless.py | 125 +++++++++------ 7 files changed, 275 insertions(+), 62 deletions(-) diff --git a/test/dynamo/test_higher_order_ops.py b/test/dynamo/test_higher_order_ops.py index f16da1dfd5c31b..ae036cfbb52780 100644 --- a/test/dynamo/test_higher_order_ops.py +++ b/test/dynamo/test_higher_order_ops.py @@ -3586,6 +3586,174 @@ def wrapper_fn(x): ) self.assertEqual(actual, expected) + @config.patch(inline_inbuilt_nn_modules=True) + def test_functional_call(self): + def wrapper_fn(model, params, inputs, targets): + prediction = torch.func.functional_call(model, params, (inputs,)) + return torch.nn.functional.mse_loss(prediction, targets) + + model = torch.nn.Linear(3, 3) + params = dict(model.named_parameters()) + inputs = torch.randn(64, 3) + targets = torch.randn(64, 3) + + wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets)) + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + if torch._dynamo.config.inline_inbuilt_nn_modules: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): + l_model_parameters_weight_ = L_model_parameters_weight_ + l_model_parameters_bias_ = L_model_parameters_bias_ + l_inputs_ = L_inputs_ + l_targets_ = L_targets_ + + prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_); l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None + + mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None + return (mse_loss,) +""", + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"): + l_inputs_ = L_inputs_ + l_targets_ = L_targets_ + + prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None + + mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None + return (mse_loss,) +""", + ) + + @config.patch(inline_inbuilt_nn_modules=True) + def test_functional_call_sequential_params_and_buffers(self): + # copied from test/test_stateless.py + class MockModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.l1 = torch.nn.Linear(1, 1) + self.register_buffer("buffer", torch.ones(1)) + self.foo = 0.0 + + def forward(self, x): + return self.l1(x) + self.buffer + + def wrapper_fn(model, params, buffers, inputs): + # two separate dictionaries + return torch.func.functional_call(model, (params, buffers), inputs) + + model = MockModule() + params = dict(model.named_parameters()) + buffers = dict(model.named_buffers()) + inputs = torch.tensor([[1.5]]) + + wrapped_gm = self._compile_check( + wrapper_fn, (model, params, buffers, inputs), fullgraph=False + ) + # Dynamic shapes produce a slightly different graph. + if check_dynamic_shape_capture(): + return + + actual = normalize_gm(wrapped_gm.print_readable(print_output=False)) + if torch._dynamo.config.inline_inbuilt_nn_modules: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"): + l_params_l1_weight_ = L_params_l1_weight_ + l_params_l1_bias_ = L_params_l1_bias_ + l_buffers_buffer_ = L_buffers_buffer_ + l_inputs_ = L_inputs_ + + linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None + add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None + return (add,) +""", + ) + else: + self.assertExpectedInline( + actual, + """\ +class GraphModule(torch.nn.Module): + def forward(self, L_x_: "f32[1, 1]"): + l_x_ = L_x_ + + l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None + l__self___buffer: "f32[1]" = self.L__self___buffer + add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None + return (add,) +""", + ) + + @config.patch(inline_inbuilt_nn_modules=True) + def test_functional_call_disable_capture(self): + counters.clear() + + with config.patch(capture_func_transforms=False): + # We have verified above that this + # function compiles + def wrapper_fn(model, params, inputs, targets): + prediction = torch.func.functional_call(model, params, (inputs,)) + return torch.nn.functional.mse_loss(prediction, targets) + + model = torch.nn.Linear(3, 3) + params = dict(model.named_parameters()) + inputs = torch.randn(64, 3) + targets = torch.randn(64, 3) + + actual = wrapper_fn(model, params, inputs, targets) + expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( + model, params, inputs, targets + ) + self.assertEqual(len(counters["graph_break"]), 1) + self.assertEqual( + { + "torch.func.functional_call capture is disabled, it can be " + "turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1, + }, + dict(counters["graph_break"]), + ) + self.assertEqual(actual, expected) + + @config.patch(inline_inbuilt_nn_modules=False) + def test_functional_call_disable_inline_nn_module(self): + counters.clear() + + def wrapper_fn(model, params, inputs, targets): + prediction = torch.func.functional_call(model, params, (inputs,)) + return torch.nn.functional.mse_loss(prediction, targets) + + model = torch.nn.Linear(3, 3) + params = dict(model.named_parameters()) + inputs = torch.randn(64, 3) + targets = torch.randn(64, 3) + + actual = wrapper_fn(model, params, inputs, targets) + expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)( + model, params, inputs, targets + ) + self.assertEqual(len(counters["graph_break"]), 1) + self.assertEqual( + { + "torch.func.functional_call capture is disabled, it can be " + "turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1, + }, + dict(counters["graph_break"]), + ) + self.assertEqual(actual, expected) + def test_grad(self): counters.clear() diff --git a/torch/_dynamo/trace_rules.py b/torch/_dynamo/trace_rules.py index 8935f49d5cc2fc..8b04cc099b196b 100644 --- a/torch/_dynamo/trace_rules.py +++ b/torch/_dynamo/trace_rules.py @@ -51,6 +51,7 @@ from .variables import ( BuiltinVariable, + FunctionalCallVariable, FunctorchHigherOrderVariable, NestedUserFunctionVariable, SkipFunctionVariable, @@ -273,6 +274,9 @@ "torch._functorch.eager_transforms.safe_unflatten": UserFunctionVariable, # functorch/hessian "torch._functorch.eager_transforms.hessian": FunctorchHigherOrderVariable, + # functional_call + "torch._functorch.functional_call.functional_call": FunctionalCallVariable, + "torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable, # functorch/deprecated "torch._functorch.deprecated.jvp": UserFunctionVariable, "torch._functorch.deprecated.hessian": UserFunctionVariable, @@ -2260,9 +2264,6 @@ "torch._functorch.eager_transforms.functionalize", "torch._functorch.eager_transforms.lazy_dynamo_disable", "torch._functorch.eager_transforms.noop", - "torch._functorch.functional_call.construct_stacked_leaf", - "torch._functorch.functional_call.functional_call", - "torch._functorch.functional_call.stack_module_state", "torch._functorch.pyfunctorch.coerce_cinterpreter", "torch._functorch.pyfunctorch.dispatch_functorch", "torch._functorch.pyfunctorch.nested", @@ -3233,6 +3234,7 @@ def _module_dir(m: types.ModuleType): "torch._higher_order_ops.strict_mode", "torch._higher_order_ops.while_loop", "torch._higher_order_ops.associative_scan", + "torch._functorch.functional_call", } @@ -3429,7 +3431,7 @@ def check_verbose(obj, is_inlined_call=False): rule = torch._dynamo.trace_rules.lookup_inner( fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons ) - if rule in [UserFunctionVariable, FunctorchHigherOrderVariable]: + if issubclass(rule, UserFunctionVariable): return SkipResult( False, f"inlined according trace_rules.lookup {reasons.pop()}", diff --git a/torch/_dynamo/variables/__init__.py b/torch/_dynamo/variables/__init__.py index 4a9c048bb6ef9c..18b53a74c7486b 100644 --- a/torch/_dynamo/variables/__init__.py +++ b/torch/_dynamo/variables/__init__.py @@ -34,6 +34,7 @@ UserMethodVariable, ) from .higher_order_ops import ( + FunctionalCallVariable, FunctorchHigherOrderVariable, TorchHigherOrderOperatorVariable, ) diff --git a/torch/_dynamo/variables/higher_order_ops.py b/torch/_dynamo/variables/higher_order_ops.py index 7d108d02f80698..42b2d1b690447b 100644 --- a/torch/_dynamo/variables/higher_order_ops.py +++ b/torch/_dynamo/variables/higher_order_ops.py @@ -1265,6 +1265,7 @@ def call_function( "jacfwd": "jacfwd", "hessian": "hessian", "linearize": "linearize", + "functional_call": "functional_call", }.get(name) assert name is not None unimplemented( @@ -1275,6 +1276,19 @@ def call_function( return super().call_function(tx, args, kwargs) +class FunctionalCallVariable(FunctorchHigherOrderVariable): + def call_function( + self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker] + ) -> VariableTracker: + if not torch._dynamo.config.inline_inbuilt_nn_modules: + unimplemented( + "torch.func.functional_call capture is disabled, " + "it can be turned on by setting " + "`torch._dynamo.config.inline_inbuilt_nn_modules=True`" + ) + return super().call_function(tx, args, kwargs) + + class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable): def create_wrapped_node( self, tx: "InstructionTranslator", args, kwargs, description diff --git a/torch/_functorch/functional_call.py b/torch/_functorch/functional_call.py index 2798dabef1eafd..86c63be17fc9d0 100644 --- a/torch/_functorch/functional_call.py +++ b/torch/_functorch/functional_call.py @@ -1,6 +1,5 @@ # mypy: allow-untyped-decorators # mypy: allow-untyped-defs -from collections import Counter from typing import Any, Dict, List, Optional, Sequence, Tuple, Union import torch @@ -128,7 +127,11 @@ def compute_loss(params, x, t): "Expected all elements of parameter_and_buffer_dicts to be dictionaries" ) all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()] - repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1] + all_keys_counter: Dict[str, int] = {} + for k in all_keys: + v = all_keys_counter.get(k, 0) + all_keys_counter[k] = v + 1 + repeated_keys = [key for key, n in all_keys_counter.items() if n > 1] if len(repeated_keys) > 0: raise ValueError( f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous" diff --git a/torch/nn/utils/_named_member_accessor.py b/torch/nn/utils/_named_member_accessor.py index e46318b0d3acb5..f1f5a117e685d2 100644 --- a/torch/nn/utils/_named_member_accessor.py +++ b/torch/nn/utils/_named_member_accessor.py @@ -59,13 +59,11 @@ def swap_tensor( else: del module._buffers[name] else: - try: + if hasattr(module, name): orig_tensor = getattr(module, name) - except AttributeError as ex: + else: if not allow_missing: - raise AttributeError( - f"{module._get_name()} has no attribute `{name}`" - ) from ex + raise AttributeError(f"{module._get_name()} has no attribute `{name}`") orig_tensor = _MISSING if ( orig_tensor is not _MISSING @@ -132,9 +130,9 @@ def get_submodule(self, name: str) -> "torch.nn.Module": if not name: return self.module - try: + if name in self.memo: return self.memo[name] - except KeyError: + else: prefix, dot, attr = name.rpartition(".") if dot: module = self.get_submodule(prefix) diff --git a/torch/nn/utils/stateless.py b/torch/nn/utils/stateless.py index 78e5a7f184ae41..69994f315f77e6 100644 --- a/torch/nn/utils/stateless.py +++ b/torch/nn/utils/stateless.py @@ -1,7 +1,5 @@ # mypy: allow-untyped-defs -import contextlib -from collections import defaultdict -from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union +from typing import Any, Dict, Optional, Set, Tuple, Union from typing_extensions import deprecated import torch @@ -47,8 +45,10 @@ def _untie_named_tensors_map( all_named_tensors.update(module.named_buffers(remove_duplicate=False)) # A map of {tensor: set(all_tied_names)} for all tensor names in the module. - tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set) + tensor_to_tied_names_map: Dict[Tensor, Set[str]] = {} for name, tensor in all_named_tensors.items(): + if tensor not in tensor_to_tied_names_map: + tensor_to_tied_names_map[tensor] = set() tensor_to_tied_names_map[tensor].add(name) # A map of {tied_name: set(all_tied_names)} for all tensor names in the module. @@ -61,7 +61,13 @@ def _untie_named_tensors_map( # Make sure the user didn't pass multiple values for the same tied tensor. given_names = set(parameters_and_buffers.keys()) - given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys()) + # same as given_names.intersection(tied_names_map.keys()) but dynamo can't + # handle that + given_names_for_tied_tensors: set[str] = set() + for name in given_names: + if name in tied_names_map: + given_names_for_tied_tensors.add(name) + for given_name in given_names_for_tied_tensors: tied_names = tied_names_map[given_name] if ( @@ -88,68 +94,89 @@ def _untie_named_tensors_map( return untied_parameters_and_buffers -@contextlib.contextmanager -def _reparametrize_module( - module: "torch.nn.Module", - parameters_and_buffers: Dict[str, Tensor], - *, - tie_weights: bool = False, - strict: bool = False, - stack_weights: bool = False, -) -> Iterator[None]: - if tie_weights: - untied_parameters_and_buffers = _untie_named_tensors_map( - module, parameters_and_buffers - ) - else: - untied_parameters_and_buffers = parameters_and_buffers +class _ReparametrizeModule: + def __init__( + self, + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, + ): + self.parameters_and_buffers = parameters_and_buffers + self.stack_weights = stack_weights - accessor = NamedMemberAccessor(module) - if strict: - missing_keys, unexpected_keys = accessor.check_keys( - untied_parameters_and_buffers - ) - error_msgs = [] - if len(unexpected_keys) > 0: - error_msgs.append( - f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + if tie_weights: + self.untied_parameters_and_buffers = _untie_named_tensors_map( + module, parameters_and_buffers ) - if len(missing_keys) > 0: - error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.") - if len(error_msgs) > 0: - raise RuntimeError( - "Error(s) in reparametrizing for {}:\n\t{}".format( - module._get_name(), "\n\t".join(error_msgs) - ) + else: + self.untied_parameters_and_buffers = parameters_and_buffers + + self.accessor = NamedMemberAccessor(module) + if strict: + missing_keys, unexpected_keys = self.accessor.check_keys( + self.untied_parameters_and_buffers ) + error_msgs = [] + if len(unexpected_keys) > 0: + error_msgs.append( + f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}." + ) + if len(missing_keys) > 0: + error_msgs.append( + f"Missing key(s): {', '.join(map(repr, missing_keys))}." + ) + if len(error_msgs) > 0: + raise RuntimeError( + "Error(s) in reparametrizing for {}:\n\t{}".format( + module._get_name(), "\n\t".join(error_msgs) + ) + ) - orig_parameters_and_buffers: Dict[str, Tensor] = {} - try: - orig_parameters_and_buffers, _ = accessor.swap_tensors_dict( - untied_parameters_and_buffers, allow_missing=True + def __enter__(self): + self.orig_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( + self.untied_parameters_and_buffers, allow_missing=True ) - yield - finally: - if stack_weights: + + def __exit__(self, exception_type, exception_value, traceback): + if self.stack_weights: # When stacking is enabled, we will restore the weights in LIFO order. - orig_parameters_and_buffers = dict( - reversed(orig_parameters_and_buffers.items()) + self.orig_parameters_and_buffers = dict( + reversed(self.orig_parameters_and_buffers.items()) ) - new_parameters_and_buffers, _ = accessor.swap_tensors_dict( - orig_parameters_and_buffers, allow_missing=True + new_parameters_and_buffers, _ = self.accessor.swap_tensors_dict( + self.orig_parameters_and_buffers, allow_missing=True ) # Sometimes the module is not completely stateless and has some in-place modifications on # the _parameters and _buffers dictionaries. # Write the changed parameters and buffers back to the original dict. - parameters_and_buffers.update( + self.parameters_and_buffers.update( { k: new_parameters_and_buffers[k] - for k in parameters_and_buffers + for k in self.parameters_and_buffers if k in new_parameters_and_buffers } ) +def _reparametrize_module( + module: "torch.nn.Module", + parameters_and_buffers: Dict[str, Tensor], + *, + tie_weights: bool = False, + strict: bool = False, + stack_weights: bool = False, +) -> _ReparametrizeModule: + return _ReparametrizeModule( + module, + parameters_and_buffers, + tie_weights=tie_weights, + strict=strict, + stack_weights=stack_weights, + ) + + @deprecated( "`torch.nn.utils.stateless.functional_call` is deprecated as of PyTorch 2.0 " "and will be removed in a future version of PyTorch. "