[export] Clean up verifier [1/n]. (pytorch#112505)
Summary: Some adjustments to the verifier so that it is easier to use correctly. We will enable the verifier later, so the current diff is a no-op.

Test Plan: CI

Differential Revision: D50839295

Pull Request resolved: pytorch#112505
Approved by: https://github.com/tugsbayasgalan, https://github.com/angelayi
zhxchen17 authored and pytorchmergebot committed Nov 2, 2023
1 parent 8198474 commit 50767a0
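
For orientation before the diff: the user-visible change is that a Verifier is no longer invoked directly on a GraphModule; the public entry point is now Verifier.check, which takes the whole ExportedProgram. A minimal sketch of the before/after call style, assuming this commit's torch._export APIs (the function f and its inputs are illustrative, taken from the updated tests):

import torch
from torch._export import export
from torch._export.verifier import Verifier

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

ep = export(f, (torch.randn(100), torch.randn(100)))

verifier = Verifier()
# Before this change the verifier was called on the graph module directly:
#     verifier(ep.graph_module)
# After this change, check() receives the full ExportedProgram, giving the
# verifier access to more than just the graph:
verifier.check(ep)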
Showing 9 changed files with 173 additions and 173 deletions.
3 changes: 3 additions & 0 deletions docs/source/export.rst
@@ -577,3 +577,6 @@ API Reference
 .. autoclass:: InputSpec
 .. autoclass:: OutputKind
 .. autoclass:: OutputSpec
+.. autoclass:: ExportGraphSignature
+
+    .. automethod:: replace_all_uses
45 changes: 13 additions & 32 deletions test/export/test_verifier.py
@@ -7,7 +7,7 @@
 from torch._dynamo.eval_frame import is_dynamo_supported
 from torch._export import export
 
-from torch._export.verifier import ATenDialectVerifier, SpecViolationError, Verifier
+from torch._export.verifier import SpecViolationError, Verifier
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 from torch.testing._internal.common_utils import run_tests, TestCase
@@ -20,7 +20,7 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         ep = export(f, (torch.randn(100), torch.randn(100)))
 
         verifier = Verifier()
-        verifier(ep.graph_module)
+        verifier.check(ep)
 
     def test_verifier_call_module(self) -> None:
         class M(torch.nn.Module):
@@ -35,7 +35,7 @@ def forward(self, x: Tensor) -> Tensor:
 
         verifier = Verifier()
         with self.assertRaises(SpecViolationError):
-            verifier(gm)
+            verifier._check_graph_module(gm)
 
     def test_verifier_no_functional(self) -> None:
         def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -48,7 +48,7 @@ def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
         verifier = Verifier()
         with self.assertRaises(SpecViolationError):
-            verifier(ep.graph_module)
+            verifier.check(ep)
 
     def test_verifier_higher_order(self) -> None:
         def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -65,7 +65,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
         ep = export(f, (torch.randn(3, 3), torch.randn(3, 3)))
 
         verifier = Verifier()
-        verifier(ep.graph_module)
+        verifier.check(ep)
 
     def test_verifier_nested_invalid_module(self) -> None:
         def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -86,22 +86,7 @@ def false_fn(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
 
         verifier = Verifier()
         with self.assertRaises(SpecViolationError):
-            verifier(ep.graph_module)
-
-    def test_aten_verifier_wrong_op(self) -> None:
-        class TestModel(torch.nn.Module):
-            def __init__(self):
-                super().__init__()
-
-            def forward(self, x):
-                return torch.ops.aten._add_relu(x, x)
-
-        m = TestModel()
-        egm = torch.fx.symbolic_trace(m)
-        verifier = ATenDialectVerifier()
-        with self.assertRaises(SpecViolationError):
-            verifier(egm)
-        self.assertFalse(verifier.is_valid(egm))
+            verifier.check(ep)
 
     def test_ep_verifier_basic(self) -> None:
         class M(torch.nn.Module):
@@ -219,17 +204,13 @@ def forward(self, x1, x2):
         ep = export(M(), (torch.tensor(5.0), torch.tensor(6.0)))
 
         output_node = list(ep.graph.nodes)[-1]
-        with ep.graph.inserting_before(output_node):
-            additional_output_node = ep.graph.call_function(
-                torch.add, args=(output_node.args[0][0], output_node.args[0][0])
-            )
-            output_node.args = (
-                (
-                    output_node.args[0][0],
-                    additional_output_node,
-                    output_node.args[0][1],
-                ),
-            )
+        output_node.args = (
+            (
+                output_node.args[0][0],
+                list(ep.graph.nodes)[0],
+                output_node.args[0][1],
+            ),
+        )
 
         with self.assertRaisesRegex(SpecViolationError, "Number of output nodes"):
             ep._validate()
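
As the updated tests show, there are now two entry points: the public Verifier.check, which validates a whole ExportedProgram, and the internal Verifier._check_graph_module, which runs only the graph-level checks (used above where only a bare symbolically traced module exists). A small sketch of the distinction, assuming this commit's APIs (f and its inputs are illustrative):

import torch
from torch._export import export
from torch._export.verifier import Verifier

def f(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    return x + y

ep = export(f, (torch.randn(100), torch.randn(100)))
verifier = Verifier()

verifier.check(ep)                             # public: whole-program checks
verifier._check_graph_module(ep.graph_module)  # internal: graph-only checks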
9 changes: 5 additions & 4 deletions torch/_export/__init__.py
@@ -74,7 +74,7 @@
 )
 from .passes.lift_constant_tensor_pass import lift_constant_tensor_pass
 from .passes.remove_runtime_assertions import _RemoveRuntimeAssertionsPass
-from .passes.replace_sym_size_ops_pass import _ReplaceSymSizeOpPass
+from .passes.replace_sym_size_ops_pass import _replace_sym_size_ops_pass
 from .passes.replace_view_ops_with_view_copy_ops_pass import (
     ReplaceViewOpsWithViewCopyOpsPass,
 )
@@ -804,11 +804,13 @@ def make_argument_spec(node) -> ArgumentSpec:
     }
 
     if len(preserve_module_call_signature) > 0:
-        res = CollectTracepointsPass(module_call_signatures)(gm)
+        res = CollectTracepointsPass(module_call_signatures, export_graph_signature)(gm)
         assert res is not None
         gm = res.graph_module
 
     assert orig_out_spec is not None
+    lift_constant_tensor_pass(gm, export_graph_signature, params_buffers)
+    _replace_sym_size_ops_pass(gm)
     exported_program = ExportedProgram(
         gm,
         gm.graph,
@@ -826,9 +828,8 @@
         exported_program = exported_program._transform(
             _AddRuntimeAssertionsForInlineConstraintsPass(range_constraints, equality_constraints)
         )
-    exported_program = lift_constant_tensor_pass(exported_program)
 
-    return exported_program._transform(_ReplaceSymSizeOpPass())
+    return exported_program
 
 
 def _reorder_kwargs_by_names(arg_names: List[str], args: Tuple[Any], kwargs: Dict[str, Any]):
4 changes: 3 additions & 1 deletion torch/_export/passes/collect_tracepoints_pass.py
@@ -13,9 +13,10 @@ class CollectTracepointsPass(PassBase):
     Performs constant folding and constant propagation.
     """
 
-    def __init__(self, specs) -> None:
+    def __init__(self, specs, sig) -> None:
         super().__init__()
         self.specs = specs
+        self.sig = sig
 
     def call(self, gm):
         def get_arg_spec(arg):
@@ -55,6 +56,7 @@ def get_arg_spec(arg):
                     assert isinstance(user.args[1], int)
                     if user.args[1] == i:
                         user.replace_all_uses_with(arg)
+                        self.sig.replace_all_uses(user.name, arg.name)
                         break
             users = list(node.users)
             for user in users:
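
The new sig argument keeps the ExportGraphSignature consistent with graph edits: the signature records inputs and outputs by node name, so rewiring the fx graph with replace_all_uses_with must be mirrored by sig.replace_all_uses on the names. A minimal sketch of that invariant, assuming this commit's torch._export (the node names passed at the end are hypothetical):

import torch
from torch._export import export

def f(x: torch.Tensor) -> torch.Tensor:
    return x * 2

ep = export(f, (torch.randn(4),))
sig = ep.graph_signature

# Any pass that replaces node `old` with node `new` should do both:
#     old.replace_all_uses_with(new)            # fx-level rewrite
#     sig.replace_all_uses(old.name, new.name)  # signature-level rename
# Call shape only; these names do not exist in this toy graph, and the
# rename is assumed to be a no-op when the old name never appears:
sig.replace_all_uses("old_node", "new_node")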
28 changes: 13 additions & 15 deletions torch/_export/passes/lift_constant_tensor_pass.py
@@ -3,48 +3,47 @@
 from torch.export.exported_program import InputKind, InputSpec, TensorArgument
 
 
-def lift_constant_tensor_pass(ep):
+def lift_constant_tensor_pass(gm, graph_signature, state_dict):
     """
     Takes an ExportedProgram and returns the ExportedProgram modified in-place,
     with the constant tensors as buffers.
     """
-    if len([node for node in ep.graph.nodes if node.op == "placeholder"]) == 0:
-        return ep
+    if len([node for node in gm.graph.nodes if node.op == "placeholder"]) == 0:
+        return
 
-    graph_signature = ep.graph_signature
     buffers = graph_signature.buffers
 
     fake_mode = detect_fake_mode(
-        tuple(node.meta["val"] for node in ep.graph.nodes if node.op == "placeholder")
+        tuple(node.meta["val"] for node in gm.graph.nodes if node.op == "placeholder")
     )
     assert fake_mode is not None
 
     first_user_input = None
     lifted_buffers = []
-    for node in ep.graph.nodes:
+    for node in gm.graph.nodes:
         if node.op == "placeholder" and node.name in graph_signature.user_inputs:
             first_user_input = node
             break
 
-    for node in ep.graph.nodes:
+    for node in gm.graph.nodes:
         if node.op == "get_attr":
-            constant_tensor = getattr(ep.graph_module, node.target)
+            constant_tensor = getattr(gm, node.target)
             if not isinstance(constant_tensor, torch.Tensor):
                 continue
 
             constant_tensor_fqn = f"_lifted_tensor_constant{len(buffers)}"
 
-            with ep.graph.inserting_before(first_user_input):
+            with gm.graph.inserting_before(first_user_input):
                 # Insert the constant node before the first user input
-                const_placeholder_node = ep.graph.placeholder(constant_tensor_fqn)
+                const_placeholder_node = gm.graph.placeholder(constant_tensor_fqn)
                 for k, v in node.meta.items():
                     const_placeholder_node.meta[k] = v
                 const_placeholder_node.meta["val"] = fake_mode.from_tensor(
                     constant_tensor, static_shapes=True
                 )
                 const_placeholder_node.meta["val"].constant = constant_tensor
                 node.replace_all_uses_with(const_placeholder_node)
-                ep.graph.erase_node(node)
+                gm.graph.erase_node(node)
 
             # Add the constant as a buffer to the graph signature
             lifted_buffers.append(
@@ -55,14 +54,13 @@ def lift_constant_tensor_pass(ep):
                 )
             )
             buffers.append(constant_tensor_fqn)
-            ep.state_dict[constant_tensor_fqn] = constant_tensor
+            state_dict[constant_tensor_fqn] = constant_tensor
 
     new_input_specs = []
     for s in graph_signature.input_specs:
         if s.kind == InputKind.USER_INPUT and len(lifted_buffers) > 0:
            new_input_specs.extend(lifted_buffers)
            lifted_buffers.clear()
         new_input_specs.append(s)
-    ep.graph_signature.input_specs = new_input_specs
-    ep.graph_module.recompile()
-    return ep
+    graph_signature.input_specs = new_input_specs
+    gm.recompile()
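
To see what the lifting does end to end, here is a hedged sketch (assuming this commit's torch._export): a plain tensor attribute on a module is neither a Parameter nor a registered buffer, so export lifts it into the graph signature as an extra input named _lifted_tensor_constant<N>, per the naming in the pass above.

import torch
from torch._export import export

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # A constant tensor attribute: not a Parameter, not a registered buffer.
        self.const = torch.randn(4)

    def forward(self, x):
        return x + self.const

ep = export(M(), (torch.randn(4),))
# The constant should now appear as a lifted input in the signature and in
# the state dict under a "_lifted_tensor_constant<N>" name.
print(ep.graph_signature.input_specs)
print(list(ep.state_dict.keys()))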
24 changes: 7 additions & 17 deletions torch/_export/passes/replace_sym_size_ops_pass.py
@@ -1,7 +1,6 @@
 from typing import Dict
 
 import torch
-from torch.fx.passes.infra.pass_base import PassBase, PassResult
 
 replacements: Dict[torch._ops.OpOverloadPacket, torch._ops.OpOverload] = {
     torch.ops.aten.sym_size: torch.ops.aten.sym_size.int,
@@ -10,19 +9,10 @@
 }
 
 
-class _ReplaceSymSizeOpPass(PassBase):
-    """
-    Replace torch.ops.aten.sym_size with torch.ops.aten.sym_size.int
-    and torch.ops.aten.sym_stride with torch.ops.aten.sym_stride.int
-    """
-
-    def call(self, graph_module) -> PassResult:
-        modified = False
-        for module in graph_module.modules():
-            if not isinstance(module, torch.fx.GraphModule):
-                continue
-            for node in module.graph.nodes:
-                if node.target in replacements:
-                    node.target = replacements[node.target]
-                    modified = True
-        return PassResult(graph_module, modified)
+def _replace_sym_size_ops_pass(gm: torch.fx.GraphModule):
+    for module in gm.modules():
+        if not isinstance(module, torch.fx.GraphModule):
+            continue
+        for node in module.graph.nodes:
+            if node.target in replacements:
+                node.target = replacements[node.target]
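
A quick way to exercise the new function form, as a toy sketch (assuming this commit's module path): splice a node whose target is the overload packet torch.ops.aten.sym_size into a traced graph, run the pass, and observe the target normalized to the .int overload.

import torch
import torch.fx
from torch._export.passes.replace_sym_size_ops_pass import _replace_sym_size_ops_pass

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = torch.fx.symbolic_trace(M())

# Manually insert a call whose target is the overload *packet*, the form
# this pass normalizes to a concrete overload.
placeholder = next(n for n in gm.graph.nodes if n.op == "placeholder")
with gm.graph.inserting_after(placeholder):
    node = gm.graph.call_function(torch.ops.aten.sym_size, (placeholder, 0))

_replace_sym_size_ops_pass(gm)
assert node.target is torch.ops.aten.sym_size.int  # packet replaced by overload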