Skip to content

Commit

Permalink
[export] Skip noop runtime assertion pass. (pytorch#109395)
Browse files Browse the repository at this point in the history
Summary:
If there's no inline constraints added, just return the original graph.
We want to do this because sometimes this pass mess up the node names,
before we actually fix this, we could make the behavior a bit less buggy
by skipping noop passes.

Test Plan:

Reviewers:

Subscribers:

Tasks:

Tags:

Fixes #ISSUE_NUMBER

Pull Request resolved: pytorch#109395
Approved by: https://github.com/angelayi
  • Loading branch information
zhxchen17 authored and pytorchmergebot committed Sep 18, 2023
1 parent 550b0ec commit 6f4b9cc
Showing 1 changed file with 7 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def __init__(
super().__init__()
self.range_constraints: Dict[sympy.Symbol, RangeConstraint] = range_constraints
self.equality_constraints: List[Tuple[InputDim, InputDim]] = equality_constraints
self.counter = 0

def _assert_range_constraint(self, proxy, lower, upper, assert_msg):
if lower > -math.inf:
Expand All @@ -72,6 +73,7 @@ def _insert_assert_async(self, operator, lower, upper, assert_msg):
Inserts assert_async call_function nodes in the graph. This function is
called **during** the interpreter-based pass.
"""
self.counter += 1
cmp = super().call_operator(operator, (lower, upper), {}, self._create_dummy_node_metadata())
cmp_tensor = super().call_operator(torch.ops.aten.scalar_tensor.default, (cmp,), {}, self._create_dummy_node_metadata())
super().call_operator(
Expand Down Expand Up @@ -139,6 +141,11 @@ def call(self, graph_module):
# Add runtime asserts for inline constraints
val = super().call(graph_module)

# Sometimes this pass would return a wrong graph where we have mismatched
# node names in signature. Before we fix it, let's just skip it.
if self.counter == 0 and type(self) is _AddRuntimeAssertionsForInlineConstraintsPass:
return PassResult(graph_module, False)

# Populate the stack trace with dummy vals to respect IR
for node in val.graph_module.graph.nodes:
if not node.meta.get("stack_trace", None):
Expand Down

0 comments on commit 6f4b9cc

Please sign in to comment.