Let dynamo inline functional_call (pytorch#128646)
Pull Request resolved: pytorch#128646
Approved by: https://github.com/zou3519
guilhermeleobas authored and pytorchmergebot committed Jul 30, 2024
1 parent 12b67bd commit a843178
Showing 7 changed files with 275 additions and 62 deletions.
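
As context for the tests added below, here is a minimal sketch (not part of the diff) of what this change enables: compiling a function that calls torch.func.functional_call without a graph break once inline_inbuilt_nn_modules is turned on. The wrapper_fn and shapes mirror the new test_functional_call test; the explicit config patch is an opt-in, since the flag defaults to off in this version.

import torch
import torch._dynamo  # ensures the dynamo config module is loaded

def wrapper_fn(model, params, inputs, targets):
    prediction = torch.func.functional_call(model, params, (inputs,))
    return torch.nn.functional.mse_loss(prediction, targets)

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)

# With inline_inbuilt_nn_modules enabled, dynamo inlines functional_call into
# the captured graph instead of graph-breaking on it.
with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
    loss = torch.compile(wrapper_fn, backend="aot_eager")(model, params, inputs, targets)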
168 changes: 168 additions & 0 deletions test/dynamo/test_higher_order_ops.py
@@ -3586,6 +3586,174 @@ def wrapper_fn(x):
)
self.assertEqual(actual, expected)

@config.patch(inline_inbuilt_nn_modules=True)
def test_functional_call(self):
def wrapper_fn(model, params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)

wrapped_gm = self._compile_check(wrapper_fn, (model, params, inputs, targets))
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return

actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
if torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_model_parameters_weight_: "f32[3, 3]", L_model_parameters_bias_: "f32[3]", L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
l_model_parameters_weight_ = L_model_parameters_weight_
l_model_parameters_bias_ = L_model_parameters_bias_
l_inputs_ = L_inputs_
l_targets_ = L_targets_
prediction: "f32[64, 3]" = torch._C._nn.linear(l_inputs_, l_model_parameters_weight_, l_model_parameters_bias_); l_inputs_ = l_model_parameters_weight_ = l_model_parameters_bias_ = None
mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
return (mse_loss,)
""",
)
else:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_inputs_: "f32[64, 3]", L_targets_: "f32[64, 3]"):
l_inputs_ = L_inputs_
l_targets_ = L_targets_
prediction: "f32[64, 3]" = self.model(l_inputs_); l_inputs_ = None
mse_loss: "f32[]" = torch.nn.functional.mse_loss(prediction, l_targets_); prediction = l_targets_ = None
return (mse_loss,)
""",
)

@config.patch(inline_inbuilt_nn_modules=True)
def test_functional_call_sequential_params_and_buffers(self):
# copied from test/test_stateless.py
class MockModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.l1 = torch.nn.Linear(1, 1)
self.register_buffer("buffer", torch.ones(1))
self.foo = 0.0

def forward(self, x):
return self.l1(x) + self.buffer

def wrapper_fn(model, params, buffers, inputs):
# two separate dictionaries
return torch.func.functional_call(model, (params, buffers), inputs)

model = MockModule()
params = dict(model.named_parameters())
buffers = dict(model.named_buffers())
inputs = torch.tensor([[1.5]])

wrapped_gm = self._compile_check(
wrapper_fn, (model, params, buffers, inputs), fullgraph=False
)
# Dynamic shapes produce a slightly different graph.
if check_dynamic_shape_capture():
return

actual = normalize_gm(wrapped_gm.print_readable(print_output=False))
if torch._dynamo.config.inline_inbuilt_nn_modules:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_params_l1_weight_: "f32[1, 1]", L_params_l1_bias_: "f32[1]", L_buffers_buffer_: "f32[1]", L_inputs_: "f32[1, 1]"):
l_params_l1_weight_ = L_params_l1_weight_
l_params_l1_bias_ = L_params_l1_bias_
l_buffers_buffer_ = L_buffers_buffer_
l_inputs_ = L_inputs_
linear: "f32[1, 1]" = torch._C._nn.linear(l_inputs_, l_params_l1_weight_, l_params_l1_bias_); l_inputs_ = l_params_l1_weight_ = l_params_l1_bias_ = None
add: "f32[1, 1]" = linear + l_buffers_buffer_; linear = l_buffers_buffer_ = None
return (add,)
""",
)
else:
self.assertExpectedInline(
actual,
"""\
class GraphModule(torch.nn.Module):
def forward(self, L_x_: "f32[1, 1]"):
l_x_ = L_x_
l__self___l1: "f32[1, 1]" = self.L__self___l1(l_x_); l_x_ = None
l__self___buffer: "f32[1]" = self.L__self___buffer
add: "f32[1, 1]" = l__self___l1 + l__self___buffer; l__self___l1 = l__self___buffer = None
return (add,)
""",
)

@config.patch(inline_inbuilt_nn_modules=True)
def test_functional_call_disable_capture(self):
counters.clear()

with config.patch(capture_func_transforms=False):
# We have verified above that this
# function compiles
def wrapper_fn(model, params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)

actual = wrapper_fn(model, params, inputs, targets)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
model, params, inputs, targets
)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
{
"torch.func.functional_call capture is disabled, it can be "
"turned on by setting `torch._dynamo.config.capture_func_transforms=True`": 1,
},
dict(counters["graph_break"]),
)
self.assertEqual(actual, expected)

@config.patch(inline_inbuilt_nn_modules=False)
def test_functional_call_disable_inline_nn_module(self):
counters.clear()

def wrapper_fn(model, params, inputs, targets):
prediction = torch.func.functional_call(model, params, (inputs,))
return torch.nn.functional.mse_loss(prediction, targets)

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
inputs = torch.randn(64, 3)
targets = torch.randn(64, 3)

actual = wrapper_fn(model, params, inputs, targets)
expected = torch.compile(wrapper_fn, backend="aot_eager", fullgraph=False)(
model, params, inputs, targets
)
self.assertEqual(len(counters["graph_break"]), 1)
self.assertEqual(
{
"torch.func.functional_call capture is disabled, it can be "
"turned on by setting `torch._dynamo.config.inline_inbuilt_nn_modules=True`": 1,
},
dict(counters["graph_break"]),
)
self.assertEqual(actual, expected)

def test_grad(self):
counters.clear()

10 changes: 6 additions & 4 deletions torch/_dynamo/trace_rules.py
@@ -51,6 +51,7 @@

from .variables import (
BuiltinVariable,
FunctionalCallVariable,
FunctorchHigherOrderVariable,
NestedUserFunctionVariable,
SkipFunctionVariable,
@@ -273,6 +274,9 @@
"torch._functorch.eager_transforms.safe_unflatten": UserFunctionVariable,
# functorch/hessian
"torch._functorch.eager_transforms.hessian": FunctorchHigherOrderVariable,
# functional_call
"torch._functorch.functional_call.functional_call": FunctionalCallVariable,
"torch.nn.utils.stateless._groupby_tensor": TorchInGraphFunctionVariable,
# functorch/deprecated
"torch._functorch.deprecated.jvp": UserFunctionVariable,
"torch._functorch.deprecated.hessian": UserFunctionVariable,
@@ -2260,9 +2264,6 @@
"torch._functorch.eager_transforms.functionalize",
"torch._functorch.eager_transforms.lazy_dynamo_disable",
"torch._functorch.eager_transforms.noop",
"torch._functorch.functional_call.construct_stacked_leaf",
"torch._functorch.functional_call.functional_call",
"torch._functorch.functional_call.stack_module_state",
"torch._functorch.pyfunctorch.coerce_cinterpreter",
"torch._functorch.pyfunctorch.dispatch_functorch",
"torch._functorch.pyfunctorch.nested",
@@ -3233,6 +3234,7 @@ def _module_dir(m: types.ModuleType):
"torch._higher_order_ops.strict_mode",
"torch._higher_order_ops.while_loop",
"torch._higher_order_ops.associative_scan",
"torch._functorch.functional_call",
}


@@ -3429,7 +3431,7 @@ def check_verbose(obj, is_inlined_call=False):
rule = torch._dynamo.trace_rules.lookup_inner(
fi.py_obj, fi.name, fi.filename, is_inlined_call, reasons
)
if rule in [UserFunctionVariable, FunctorchHigherOrderVariable]:
if issubclass(rule, UserFunctionVariable):
return SkipResult(
False,
f"inlined according trace_rules.lookup {reasons.pop()}",
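
The trace_rules changes above map torch._functorch.functional_call.functional_call to the new FunctionalCallVariable and broaden the check_verbose test from an explicit class list to issubclass(rule, UserFunctionVariable), so subclasses such as FunctionalCallVariable (via FunctorchHigherOrderVariable) are inlined as well. A hedged sketch of how the mapping can be inspected — trace_rules.lookup is assumed to be the resolving helper here, and the exact entry point may differ across versions:

import torch
import torch._dynamo.trace_rules as trace_rules

rule = trace_rules.lookup(torch.func.functional_call)
print(rule)  # expected: torch._dynamo.variables.higher_order_ops.FunctionalCallVariable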
1 change: 1 addition & 0 deletions torch/_dynamo/variables/__init__.py
@@ -34,6 +34,7 @@
UserMethodVariable,
)
from .higher_order_ops import (
FunctionalCallVariable,
FunctorchHigherOrderVariable,
TorchHigherOrderOperatorVariable,
)
14 changes: 14 additions & 0 deletions torch/_dynamo/variables/higher_order_ops.py
@@ -1265,6 +1265,7 @@ def call_function(
"jacfwd": "jacfwd",
"hessian": "hessian",
"linearize": "linearize",
"functional_call": "functional_call",
}.get(name)
assert name is not None
unimplemented(
@@ -1275,6 +1276,19 @@
return super().call_function(tx, args, kwargs)


class FunctionalCallVariable(FunctorchHigherOrderVariable):
def call_function(
self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
) -> VariableTracker:
if not torch._dynamo.config.inline_inbuilt_nn_modules:
unimplemented(
"torch.func.functional_call capture is disabled, "
"it can be turned on by setting "
"`torch._dynamo.config.inline_inbuilt_nn_modules=True`"
)
return super().call_function(tx, args, kwargs)


class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
def create_wrapped_node(
self, tx: "InstructionTranslator", args, kwargs, description
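
FunctionalCallVariable above routes torch.func.functional_call through the functorch higher-order handling and raises a graph break when inline_inbuilt_nn_modules is off. Here is a hedged sketch (not part of the diff) of that fallback, mirroring test_functional_call_disable_inline_nn_module; the compiled call still returns the correct result, it just runs functional_call outside the captured graph:

import torch
import torch._dynamo
from torch._dynamo.utils import counters

def fn(model, params, x):
    return torch.func.functional_call(model, params, (x,))

model = torch.nn.Linear(3, 3)
params = dict(model.named_parameters())
x = torch.randn(4, 3)

counters.clear()
with torch._dynamo.config.patch(inline_inbuilt_nn_modules=False):
    out = torch.compile(fn, backend="aot_eager", fullgraph=False)(model, params, x)

# Expected to contain the "torch.func.functional_call capture is disabled ..." entry.
print(dict(counters["graph_break"]))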
7 changes: 5 additions & 2 deletions torch/_functorch/functional_call.py
@@ -1,6 +1,5 @@
# mypy: allow-untyped-decorators
# mypy: allow-untyped-defs
from collections import Counter
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import torch
@@ -128,7 +127,11 @@ def compute_loss(params, x, t):
"Expected all elements of parameter_and_buffer_dicts to be dictionaries"
)
all_keys = [k for d in parameter_and_buffer_dicts for k in d.keys()]
repeated_keys = [key for key, n in Counter(all_keys).items() if n > 1]
all_keys_counter: Dict[str, int] = {}
for k in all_keys:
v = all_keys_counter.get(k, 0)
all_keys_counter[k] = v + 1
repeated_keys = [key for key, n in all_keys_counter.items() if n > 1]
if len(repeated_keys) > 0:
raise ValueError(
f"{repeated_keys} appeared in multiple dictionaries; behavior of functional call is ambiguous"
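
The change above rewrites the collections.Counter-based duplicate-key check with a plain dict, presumably so dynamo can inline functional_call without special handling for Counter (an assumption; the diff itself states no motivation). A minimal sketch of the check with illustrative keys:

# "l1.bias" appears in both dictionaries, so functional_call would raise a
# ValueError about the ambiguous key.
params = {"l1.weight": 0, "l1.bias": 1}
buffers = {"l1.bias": 2}

all_keys = [k for d in (params, buffers) for k in d.keys()]
counts = {}
for k in all_keys:
    counts[k] = counts.get(k, 0) + 1
repeated = [k for k, n in counts.items() if n > 1]
assert repeated == ["l1.bias"]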
12 changes: 5 additions & 7 deletions torch/nn/utils/_named_member_accessor.py
@@ -59,13 +59,11 @@ def swap_tensor(
else:
del module._buffers[name]
else:
try:
if hasattr(module, name):
orig_tensor = getattr(module, name)
except AttributeError as ex:
else:
if not allow_missing:
raise AttributeError(
f"{module._get_name()} has no attribute `{name}`"
) from ex
raise AttributeError(f"{module._get_name()} has no attribute `{name}`")
orig_tensor = _MISSING
if (
orig_tensor is not _MISSING
@@ -132,9 +130,9 @@ def get_submodule(self, name: str) -> "torch.nn.Module":
if not name:
return self.module

try:
if name in self.memo:
return self.memo[name]
except KeyError:
else:
prefix, dot, attr = name.rpartition(".")
if dot:
module = self.get_submodule(prefix)
