Fix graph partitioner and make runtime assertion work with submodules in export (pytorch#125793)

Summary: This fix does three things:

1. When we add inputs from the partitioner to the top-level graph module, we insert them in the partitioner's order, which is not guaranteed to match the order of the original graph inputs. This PR fixes that.
2. When we replace autograd ops with a HOP, we create new submodules and access their outputs via getitem calls. As a result, the node names previously associated with those getitem calls get updated, so the graph no longer matches the produced graph signature. We now update the graph signature accordingly (a minimal sketch of the mechanism follows this list).
3. We now run the runtime_assertion pass before the autograd HOP pass, because otherwise the constraints won't be populated correctly.
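
As a rough illustration of item 2, here is a minimal sketch of the mechanism, assuming a hypothetical `run_node_rewrites` helper standing in for the actual rewrite (the set_grad -> HOP replacement in this diff); the `_set_replace_hook` / `get_replace_hook` calls are the same ones used in replace_set_grad_with_hop_pass.py below:

import contextlib

import torch


def rewrite_keeping_signature_in_sync(gm: torch.fx.GraphModule, graph_signature):
    # Sketch only. While the hook is installed, node replacements performed on
    # `gm` are reported to the signature's replace hook, which updates the
    # signature's input/output specs to point at the new node names (e.g. the
    # getitem nodes created when a region is outlined into a submodule).
    replace_ctx = contextlib.nullcontext()
    if graph_signature is not None:
        replace_ctx = gm._set_replace_hook(graph_signature.get_replace_hook())
    with replace_ctx:
        run_node_rewrites(gm)  # hypothetical placeholder for the actual node rewrite
    gm.recompile()
    return gm

In the actual pass, recursive calls on nested submodules pass `None` for the signature, since only the top-level graph module has a graph signature to keep in sync.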

Differential Revision: [D57130314](https://our.internmc.facebook.com/intern/diff/D57130314)
Pull Request resolved: pytorch#125793
Approved by: https://github.com/zhxchen17
tugsbayasgalan authored and pytorchmergebot committed May 9, 2024
1 parent 98821b3 commit 0e419b9
Showing 5 changed files with 177 additions and 91 deletions.
82 changes: 78 additions & 4 deletions test/export/test_export.py
@@ -92,6 +92,11 @@
"(Tensor x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)
torch.library.define(
"testlib::foo_unbacked",
"(Scalar x) -> (Tensor)",
tags=torch.Tag.pt2_compliant_tag,
)


@torch.library.impl("testlib::returns_tensor_symint", "cpu")
@@ -125,6 +130,15 @@ def foo_functional(x):
    return a.cos()


@torch.library.impl("testlib::foo_unbacked", "CompositeImplicitAutograd")
def foo_unbacked(x):
if x > 2:
return torch.ones(4, 4)
if x < 6:
return torch.ones(4, 4)
return torch.ones(4, 4)


@dataclass
class Inp:
    x: Tensor
Expand Down Expand Up @@ -2415,6 +2429,7 @@ def forward(self, x, y):

ep = export(M(), (torch.tensor(1), torch.ones(4, 5)))

# This is because we insert sym_constrain_range in the graph now
if is_non_strict_test(self._testMethodName):
    error_msg = "Invalid value range"
else:
@@ -4078,16 +4093,16 @@ def forward(self, b_pred, b_t, x, y):
"""\
def forward(self, b_t, x, y):
submod_3 = self.submod_1
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, b_t, x, y); submod_3 = b_t = x = y = None
add_1 = torch._higher_order_ops.wrap.wrap_with_set_grad_enabled(True, submod_3, x, b_t, y); submod_3 = x = b_t = y = None
return (add_1,)""",
)

self.assertExpectedInline(
str(exported_program.graph_module.true_graph_0.submod_1.code.strip()),
"""\
def forward(self, b_t, x, y):
sub = torch.ops.aten.sub.Tensor(b_t, 1); b_t = None
add = torch.ops.aten.add.Tensor(sub, x); sub = x = None
def forward(self, x, b_t, y):
sub = torch.ops.aten.sub.Tensor(x, 1); x = None
add = torch.ops.aten.add.Tensor(sub, b_t); sub = b_t = None
add_1 = torch.ops.aten.add.Tensor(add, y); add = y = None
return add_1""",
)
@@ -4587,6 +4602,65 @@ def forward(self, x, y, div="floor"):
self.assertEqual(div_spec.arg.name, "div")
self.assertEqual(div_spec.arg.value, "floor")

def test_unbacked_deferred_runtime_retrace(self):
    class Foo(torch.nn.Module):
        def forward(self, x, y):
            y_sum = y.sin().sum()
            with torch.no_grad():
                a = x.item()
                torch._check_is_size(a)
                torch._check(a > 2)
                torch._check(a < 6)
                unbacked_shape = torch.ops.testlib.foo_unbacked(a)
            return y + y_sum + unbacked_shape.sum()

    inps = (torch.tensor(4), torch.randn(5, 5))
    from torch.export import _trace

    ep_pre = _trace._export(Foo(), inps, pre_dispatch=True, strict=False)
    self.assertExpectedInline(
        str(ep_pre.graph_module.submod_1.code).strip(),
        """\
def forward(self, x):
    item = torch.ops.aten.item.default(x); x = None
    sym_constrain_range_for_size_default = torch.ops.aten.sym_constrain_range_for_size.default(item)
    sym_constrain_range_default = torch.ops.aten.sym_constrain_range.default(item, min = 3, max = 5)
    mul = -1 * item
    le = mul <= 0; mul = None
    _assert_scalar_default = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'"); le = None
    mul_1 = -1 * item
    lt = mul_1 < -2; mul_1 = None
    _assert_scalar_default_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'"); lt = None
    lt_1 = item < 6
    _assert_scalar_default_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = None
    foo_unbacked = torch.ops.testlib.foo_unbacked.default(item); item = None
    return foo_unbacked""",
    )
    ep_aot = ep_pre.run_decompositions()
    self.assertExpectedInline(
        str(ep_aot.graph_module.code).strip(),
        """\
def forward(self, x, y):
    sin = torch.ops.aten.sin.default(y)
    sum_1 = torch.ops.aten.sum.dim_IntList(sin, []); sin = None
    _local_scalar_dense = torch.ops.aten._local_scalar_dense.default(x); x = None
    sym_constrain_range_for_size = torch.ops.aten.sym_constrain_range_for_size.default(_local_scalar_dense)
    sym_constrain_range = torch.ops.aten.sym_constrain_range.default(_local_scalar_dense, min = 3, max = 5)
    mul = -1 * _local_scalar_dense
    le = mul <= 0; mul = None
    _assert_scalar = torch.ops.aten._assert_scalar.default(le, "Runtime assertion failed for expression -u1 <= 0 on node 'le'"); le = None
    mul_1 = -1 * _local_scalar_dense
    lt = mul_1 < -2; mul_1 = None
    _assert_scalar_1 = torch.ops.aten._assert_scalar.default(lt, "Runtime assertion failed for expression -u1 < -2 on node 'lt'"); lt = None
    lt_1 = _local_scalar_dense < 6; _local_scalar_dense = None
    _assert_scalar_2 = torch.ops.aten._assert_scalar.default(lt_1, "Runtime assertion failed for expression u1 < 6 on node 'lt_1'"); lt_1 = None
    full = torch.ops.aten.full.default([4, 4], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
    add = torch.ops.aten.add.Tensor(y, sum_1); y = sum_1 = None
    sum_2 = torch.ops.aten.sum.dim_IntList(full, []); full = None
    add_1 = torch.ops.aten.add.Tensor(add, sum_2); add = sum_2 = None
    return (add_1,)""",
    )

def test_nested_dynamic_shapes_spec(self):
    class Foo(torch.nn.Module):
        def forward(self, x):
46 changes: 27 additions & 19 deletions torch/_export/passes/replace_set_grad_with_hop_pass.py
@@ -1,3 +1,4 @@
import contextlib
import copy

import torch
@@ -125,7 +126,9 @@ def _remove_set_grad_and_inline(node: torch.fx.Node):
node_inline_(node)


def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
def _sequential_split_and_maybe_inline_subgraphs(
    gm: torch.fx.GraphModule, graph_signature
):
    """
    Helper function for replace_set_grad_with_hop_pass().
    Split the graph module into multiple subgraphs based on the set_grad_enabled nodes.
@@ -141,35 +144,40 @@ def _sequential_split_and_maybe_inline_subgraphs(gm: torch.fx.GraphModule):
if need_replacing:
    new_gm = sequential_split(gm, _is_set_grad_enabled_node)

    def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
        if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
            _replace_with_hop(node)
        else:
            _remove_set_grad_and_inline(node)

    nodes_map(
        list(new_gm.graph.nodes),
        lambda node: (
            _maybe_inline_or_replace_with_hop(node)
            if node.op == "call_module"
            else node
        ),
    )
    replace_ctx = contextlib.nullcontext()
    if graph_signature is not None:
        replace_ctx = new_gm._set_replace_hook(graph_signature.get_replace_hook())  # type: ignore[assignment]

    with replace_ctx:

        def _maybe_inline_or_replace_with_hop(node: torch.fx.Node):
            if _is_set_grad_enabled_sub_mod(node, omit_if_same_with_ambient=True):
                _replace_with_hop(node)
            else:
                _remove_set_grad_and_inline(node)

        nodes_map(
            list(new_gm.graph.nodes),
            lambda node: (
                _maybe_inline_or_replace_with_hop(node)
                if node.op == "call_module"
                else node
            ),
        )
    return new_gm

return gm


def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule):
    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm)

def replace_set_grad_with_hop_pass(gm: torch.fx.GraphModule, graph_signature):
    new_gm = _sequential_split_and_maybe_inline_subgraphs(gm, graph_signature)
    # recursively call
    for node in new_gm.graph.nodes:
        if node.op == "get_attr":
            subgm = getattr(new_gm, node.target)
            if not isinstance(subgm, torch.fx.GraphModule):
                continue
            new_subgm = replace_set_grad_with_hop_pass(subgm)
            new_subgm = replace_set_grad_with_hop_pass(subgm, None)
            setattr(new_gm, node.target, new_subgm)

    new_gm.recompile()
106 changes: 57 additions & 49 deletions torch/export/_trace.py
@@ -467,6 +467,7 @@ def _export_non_strict(
    *,
    transform=lambda x: x,  # TODO(zhxchen17) Revisit if this is needed later.
    pre_dispatch=False,
    should_insert_runtime_assertion=False,
):
# [NOTE] If the user is exporting under training mode, we want to detect if there is any
# state change in the autograd global state and error. If the user is exporting under inference
@@ -508,41 +509,6 @@ def _compiling_state_context():
if isinstance(mod, torch.fx.GraphModule) and hasattr(mod, "meta"):
gm.meta.update(mod.meta)

if pre_dispatch:
    from torch._export.passes.replace_set_grad_with_hop_pass import (
        replace_set_grad_with_hop_pass,
    )

    gm = replace_set_grad_with_hop_pass(gm)

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
    if not isinstance(_mod, torch.fx.GraphModule):
        continue
    for node in _mod.graph.nodes:
        if node.op in ["placeholder", "output"]:
            node.meta.pop("nn_module_stack", None)
            node.meta.pop("stack_trace", None)

# NOTE: aot_export adds symint metadata for placeholders with int values;
# since these become specialized, we replace such metadata with the original values
flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
index = 0
total_non_user_inputs = (
    len(graph_signature.parameters)
    + len(graph_signature.buffers)
    + len(graph_signature.input_tokens)
)
for node in gm.graph.nodes:
    if node.op == "placeholder":
        if index >= total_non_user_inputs:
            user_arg = flat_args[index - total_non_user_inputs]
            if not isinstance(user_arg, torch.Tensor):
                node.meta["val"] = user_arg
        index += 1

is_joint = graph_signature.backward_signature is not None

def make_argument_spec(i, node) -> ArgumentSpec:
if isinstance(node, (int, bool, float, type(None))):
# For const outputs we just directly return this
@@ -571,6 +537,25 @@ def make_argument_spec(i, node) -> ArgumentSpec:
f"while writing the metadata for exported program"
)

is_joint = graph_signature.backward_signature is not None

# NOTE: aot_export adds symint metadata for placeholders with int values;
# since these become specialized, we replace such metadata with the original values
flat_args = pytree.tree_leaves((fake_args, fake_kwargs))
index = 0
total_non_user_inputs = (
    len(graph_signature.parameters)
    + len(graph_signature.buffers)
    + len(graph_signature.input_tokens)
)
for node in gm.graph.nodes:
    if node.op == "placeholder":
        if index >= total_non_user_inputs:
            user_arg = flat_args[index - total_non_user_inputs]
            if not isinstance(user_arg, torch.Tensor):
                node.meta["val"] = user_arg
        index += 1

input_specs, output_specs = _sig_to_specs(
user_inputs=set(graph_signature.user_inputs),
inputs_to_parameters=graph_signature.inputs_to_parameters, # type: ignore[arg-type]
@@ -599,6 +584,41 @@ def make_argument_spec(i, node) -> ArgumentSpec:
input_specs=input_specs, output_specs=output_specs
)

from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(flat_args)

if should_insert_runtime_assertion:
    stack_trace = (
        'File "torch/fx/passes/runtime_assert.py", line 24, '
        "in insert_deferred_runtime_asserts"
    )
    with gm._set_create_node_hook(
        functools.partial(_node_metadata_hook, stack_trace=stack_trace)
    ):
        insert_deferred_runtime_asserts(
            gm,
            fake_mode.shape_env,
            f"non strict exported program: {first_call_function_nn_module_stack(gm.graph)}",
            export=True,
        )

if pre_dispatch:
    from torch._export.passes.replace_set_grad_with_hop_pass import (
        replace_set_grad_with_hop_pass,
    )

    gm = replace_set_grad_with_hop_pass(gm, export_graph_signature)

# Remove nn_module_stack, stack_trace metadata from all placeholders/inputs nodes.
for _mod in gm.modules():
    if not isinstance(_mod, torch.fx.GraphModule):
        continue
    for node in _mod.graph.nodes:
        if node.op in ["placeholder", "output"]:
            node.meta.pop("nn_module_stack", None)
            node.meta.pop("stack_trace", None)

constants = rewrite_script_object_meta(gm)
constants.update(lift_constants_pass(gm, export_graph_signature, constant_attrs))

@@ -1040,6 +1060,7 @@ def forward(self, *args, **kwargs):
    new_fake_constant_attrs,
    pre_dispatch=pre_dispatch,
    transform=_tuplify_outputs,
    should_insert_runtime_assertion=not strict,
)
# ep_non_strict.constants contains only fake script objects, we need to map them back
ep_non_strict.constants = {
@@ -1049,20 +1070,6 @@ def forward(self, *args, **kwargs):
for fqn, obj in ep_non_strict.constants.items()
}

stack_trace = (
    'File "torch/fx/passes/runtime_assert.py", line 24, '
    "in insert_deferred_runtime_asserts"
)
with ep_non_strict.gm._set_create_node_hook(
    functools.partial(_node_metadata_hook, stack_trace=stack_trace)
):
    insert_deferred_runtime_asserts(
        ep_non_strict.gm,
        fake_mode.shape_env,
        f"non strict exported program: {first_call_function_nn_module_stack(ep_non_strict.gm.graph)}",
        export=True,
    )

ep_non_strict.gm.meta["inline_constraints"] = {
k: v
for k, v in fake_mode.shape_env.var_to_range.items()
@@ -1246,6 +1253,7 @@ def forward(self, *args, **kwargs):
    fake_params_buffers,
    constant_attrs,
    pre_dispatch=pre_dispatch,
    should_insert_runtime_assertion=not strict,
)

gm = ep_non_strict.gm
18 changes: 0 additions & 18 deletions torch/export/exported_program.py
@@ -530,10 +530,6 @@ def run_decompositions(
For now, we do not decompose joint graphs.
"""
from torch._decomp import core_aten_decompositions
from torch._export.passes._node_metadata_hook import _node_metadata_hook
from torch._export.passes.add_runtime_assertions_for_constraints_pass import (
    _AddRuntimeAssertionsForInlineConstraintsPass,
)
from torch._export.passes.lift_constants_pass import (
    ConstantAttrMap,
    lift_constants_pass,
@@ -663,20 +659,6 @@ def update_arg(old_arg, new_ph):

_replace_sym_size_ops_pass(gm)

if len(new_range_constraints) > 0:
    stack_trace = (
        'File "torch/_export/passes/add_runtime_assertions_for_constraints_pass.py", line 46, '
        "in _AddRuntimeAssertionsForInlineConstraintsPass"
    )
    with gm._set_create_node_hook(
        functools.partial(_node_metadata_hook, stack_trace=stack_trace)
    ):
        res = _AddRuntimeAssertionsForInlineConstraintsPass(
            new_range_constraints
        )(gm)
    assert res is not None
    gm = res.graph_module

exported_program = ExportedProgram(
root=gm,
graph=gm.graph,
