Skip to content

Commit

Permalink
[export] Disable backend decomps for capture_pre_autograd (pytorch#127120)
Browse files Browse the repository at this point in the history

Differential Revision: D57785713

Pull Request resolved: pytorch#127120
Approved by: https://github.com/ydwu4
  • Loading branch information
angelayi authored and pytorchmergebot committed May 28, 2024
1 parent c404088 commit cbb79a2
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 2 deletions.
11 changes: 11 additions & 0 deletions test/export/test_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(exported_program.module()(*args), m(*args))

from torch._export import capture_pre_autograd_graph

gm: torch.fx.GraphModule = capture_pre_autograd_graph(
m, args=example_args, dynamic_shapes=dynamic_shapes
)

args = (torch.randn(17, 3, 256, 256), torch.ones(17, 32, 256, 256))
self.assertEqual(gm(*args), m(*args))
args = (torch.randn(15, 3, 256, 256), torch.ones(15, 32, 256, 256))
self.assertEqual(gm(*args), m(*args))

def test_basic_non_strict_real_tensor(self):
class Basic(torch.nn.Module):
def __init__(self):
Expand Down
4 changes: 2 additions & 2 deletions torch/_export/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def capture_pre_autograd_graph(
An nn.Module containing the traced method.
"""
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG
from torch.export._trace import _convert_input_to_fake, DEFAULT_EXPORT_DYNAMO_CONFIG, _ignore_backend_decomps
from torch._utils_internal import export_api_rollout_check

capture_pre_autograd_graph_warning()
Expand All @@ -165,7 +165,7 @@ def print_export_warning():
for op in FunctionalTensor.maybe_aliasing_or_mutating_ops
if op != torch.ops.aten.dropout.default
}
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)):
with torch._dynamo.config.patch(dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG)), _ignore_backend_decomps():
m = torch._dynamo.export(
f,
dynamic_shapes=dynamic_shapes,
Expand Down

0 comments on commit cbb79a2

Please sign in to comment.