Revert "[Dynamo] Trace torch function modes entered outside of torch.…
Browse files Browse the repository at this point in the history
…compile (pytorch#133137)"

This reverts commit 4528777.

Reverted pytorch#133137 on behalf of https://github.com/mlazos because it broke `test/quantization/pt2e/test_numeric_debugger.py TestNumericDebugger.test_re_export_preserve_handle`, which was modified yesterday ([comment](pytorch#134732 (comment)))
pytorchmergebot committed Sep 14, 2024
1 parent 46f5037 commit 8c8a308
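
For context on what was reverted: pytorch#133137 taught Dynamo to trace torch function modes that are already active when a compiled function is called. Below is a minimal sketch of that usage pattern, adapted from the removed tests in test/dynamo/test_modes.py further down; the mode and function names here are illustrative and not part of this commit.

```python
import torch
from torch.overrides import BaseTorchFunctionMode


class AddToZeros(BaseTorchFunctionMode):
    # Illustrative mode: reroute torch.add to a constant result.
    def __torch_function__(self, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is torch.add:
            return torch.zeros(2, 2)
        return super().__torch_function__(func, types, args, kwargs)


def fn(x):
    return torch.add(x, 3)


fn_opt = torch.compile(fn, fullgraph=True)

# The mode is entered *outside* torch.compile. Eagerly, the mode clearly
# intercepts torch.add; whether the compiled path traces the mode's
# __torch_function__ override is exactly what the reverted change added.
with AddToZeros():
    assert torch.equal(fn(torch.ones(2, 2)), torch.zeros(2, 2))  # eager behavior
    out = fn_opt(torch.ones(2, 2))  # compiled path; support depends on which side of the revert you are on
```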
Showing 14 changed files with 204 additions and 457 deletions.
135 changes: 0 additions & 135 deletions test/dynamo/test_modes.py
@@ -14,17 +14,6 @@
from torch.utils._python_dispatch import TorchDispatchMode


class TestMode(BaseTorchFunctionMode):
    def __torch_function__(self, func, types, args, kwargs=None):
        if not kwargs:
            kwargs = {}

        if func == torch.add:
            return torch.zeros(2, 2)

        return super().__torch_function__(func, types, args, kwargs)


class TorchDispatchModeTests(torch._dynamo.test_case.TestCase):
    @classmethod
    def setUpClass(cls):
@@ -335,130 +324,6 @@ def fn(x):
        fn(inp)
        self.assertEqual(cnt.frame_count, 2)

    def test_nested_torch_function_mode(self):
        mode_1_called = False
        mode_2_called = False

        def reset_state():
            nonlocal mode_1_called
            nonlocal mode_2_called
            mode_1_called = False
            mode_2_called = False

        ones = torch.ones(2, 2)
        zeros = torch.zeros(2, 2)

        class TestMode1(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_1_called

                mode_1_called = True

                if func == torch.add:
                    return zeros

                return super().__torch_function__(func, types, args, kwargs)

        class TestMode2(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                nonlocal mode_2_called

                mode_2_called = True

                if func == torch.mul:
                    return ones

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        def fn_2(x):
            return torch.mul(x, 3) + torch.add(x, 3)

        inp = torch.ones(2, 2) + 1

        for fn_i in [fn, fn_2]:
            fn_opt = torch.compile(fn_i, fullgraph=True)
            with TestMode1(), TestMode2():
                expected = fn_i(inp), mode_1_called, mode_2_called
                reset_state()
                actual = fn_opt(inp), mode_1_called, mode_2_called
                reset_state()

                self.assertEqual(expected, actual)

    def test_torch_function_mode_disable(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        class TestMode(BaseTorchFunctionMode):
            def __torch_function__(self, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}

                if func == torch.add:
                    return torch.zeros(2, 2)

                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            with torch._C.DisableTorchFunctionSubclass():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

            with torch._C.DisableTorchFunction():
                expected = fn(inp)
                actual = fn_opt(inp)

                self.assertEqual(expected, actual)

    def test_torch_function_mode_highest_priority(self):
        class TestSubclass(torch.Tensor):
            @classmethod
            def __torch_function__(cls, func, types, args, kwargs=None):
                if not kwargs:
                    kwargs = {}
                if func == torch.add:
                    return torch.ones(2, 2)
                return super().__torch_function__(func, types, args, kwargs)

        def fn(x):
            return torch.add(x, 3)

        inp = (torch.ones(2, 2) + 1).as_subclass(TestSubclass)

        fn_opt = torch.compile(fn, fullgraph=True)
        with TestMode(), torch._dynamo.config.patch(
            "traceable_tensor_subclasses", {TestSubclass}
        ):
            expected = fn(inp)
            actual = fn_opt(inp)

            self.assertEqual(expected, actual)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests
4 changes: 0 additions & 4 deletions test/functorch/test_control_flow.py
@@ -26,7 +26,6 @@
    parametrize,
    requires_cuda,
    run_tests,
    skipIfCrossRef,
    skipIfRocm,
    skipIfTorchDynamo,
    TEST_WITH_TORCHDYNAMO,
@@ -2883,7 +2882,6 @@ def f(fct, init, xs):
        gm = make_fx(f, tracing_mode="symbolic")(add_wrong_dtype, init, x)

    @skipIfNoDynamoSupport
    @skipIfCrossRef  # Arg order changes with crossref
    def test_scan_simple_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -2990,7 +2988,6 @@ def f(x, y):
        self.assertEqual(graph(x, torch.tensor(True)), f(x, torch.tensor(True)))

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with crossref
    def test_cond_simple_with_linear_compile_check_graph(self):
        from torch._dynamo.testing import EagerAndRecordGraphs

@@ -3253,7 +3250,6 @@ def test_while_loop_compile(self, backend, while_loop_test):
        self._check_compile(fn, inp, backend=backend)

    @skipIfTorchDynamo("Graph is not captured by backend if test with dynamo")
    @skipIfCrossRef  # Arg order changes with cross ref
    def test_while_loop_simple_with_linear_compile_check_graph(self):
        fn, inp = WHILE_LOOP_TESTS["simple_with_linear"]
        from torch._dynamo.testing import EagerAndRecordGraphs
4 changes: 1 addition & 3 deletions test/quantization/pt2e/test_metadata_porting.py
@@ -13,7 +13,7 @@
from torch.ao.quantization.quantizer.xnnpack_quantizer_utils import OP_TO_ANNOTATOR
from torch.fx import Node
from torch.testing._internal.common_quantization import QuantizationTestCase
from torch.testing._internal.common_utils import IS_WINDOWS, skipIfCrossRef
from torch.testing._internal.common_utils import IS_WINDOWS


class TestHelperModules:
@@ -139,8 +139,6 @@ def _test_metadata_porting(
            self.assertEqual(v, node_tags[k])
        return m

    @skipIfCrossRef  # mlazos: retracing FX graph with torch function mode doesn't propagate metadata, because the stack
    # trace of the mode torch function impl doesn't match the traced graph stored lineno.
    def test_simple_metadata_porting(self):
        """
        Model under test
5 changes: 0 additions & 5 deletions torch/_dynamo/convert_frame.py
@@ -605,10 +605,6 @@ def _compile(
    output: Optional[OutputGraph] = None
    tracer: Optional[InstructionTranslator] = None

    tf_mode_stack: List[
        torch.overrides.TorchFunctionMode
    ] = torch.overrides._get_current_function_mode_stack()

    @preserve_global_state
    def transform(
        instructions: List[Instruction], code_options: Dict[str, object]
@@ -622,7 +618,6 @@ def transform(
            locals,
            globals,
            builtins,
            tf_mode_stack,
            code_options,
            compiler_fn,
            one_graph,
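The `tf_mode_stack` snapshot removed above was built with `torch.overrides._get_current_function_mode_stack()`, which reads the ambient torch function mode stack. A rough sketch of what that helper observes, for orientation only (not code from this diff; assumes a fresh process with no other mode active):

```python
import torch
from torch.overrides import BaseTorchFunctionMode

# With no mode active, the ambient stack is empty.
assert torch.overrides._get_current_function_mode_stack() == []

with BaseTorchFunctionMode() as mode:
    stack = torch.overrides._get_current_function_mode_stack()
    # The active mode is now visible on the stack; the removed code snapshotted
    # this list at the start of compilation.
    assert stack[-1] is mode
```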
14 changes: 0 additions & 14 deletions torch/_dynamo/guards.py
@@ -98,7 +98,6 @@
    ScriptObjectQualifiedNameSource,
    ShapeEnvSource,
    SubclassAttrListSource,
    TorchFunctionModeStackSource,
    TupleIteratorGetItemSource,
    TypeSource,
    UnspecializedBuiltinNNModuleSource,
@@ -112,7 +111,6 @@
    dict_keys_repr,
    get_custom_getattr,
    get_torch_function_mode_stack,
    get_torch_function_mode_stack_at,
    guard_failures,
    istype,
    key_is_id,
@@ -316,7 +314,6 @@ def uninteresting_files():
    "___dict_contains": lambda a, b: a in b,
    "___tuple_iterator_len": tuple_iterator_len,
    "___tuple_iterator_getitem": tuple_iterator_getitem,
    "___get_torch_function_mode_stack_at": get_torch_function_mode_stack_at,
    "__math_isnan": math.isnan,
    "__numpy_isnan": None if np is None else np.isnan,
    "inf": float("inf"),
@@ -904,15 +901,6 @@ def get_guard_manager_from_source(self, source):
        ):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager
        elif istype(source, TorchFunctionModeStackSource):
            out = root_guard_manager.lambda_manager(
                python_lambda=lambda _: get_torch_function_mode_stack_at(
                    source._get_index()
                ),
                source=source_name,
                example_value=example_value,
                guard_manager_enum=guard_manager_enum,
            )
        elif istype(source, GradSource):
            assert base_guard_manager  # to make mypy happy
            out = base_guard_manager.grad_manager(
@@ -2226,8 +2214,6 @@ def __init__(
        self.output_graph = output_graph
        w_builder = None

        # NB: Until we trace device contexts, we need to use the stack recorded at the beginning of tracing
        # in case a set default device call was made in the graph.
        self.torch_function_mode_stack = (
            output_graph.torch_function_mode_stack if output_graph else None
        )
2 changes: 1 addition & 1 deletion torch/_dynamo/source.py
@@ -619,7 +619,7 @@ class TorchFunctionModeStackSource(Source):
    ind: int

    def name(self):
        return f"___get_torch_function_mode_stack_at({self._get_index()})"
        return ""

    def _get_index(self):
        from .variables.torch_function import TorchFunctionModeStackVariable
[Diffs for the remaining 8 changed files are not shown.]
