[aotd] Fix rrelu compilation (pytorch#136008)
Issues:
pytorch#135083
pytorch#120292

The rrelu decomposition contains a mutation (`copy_`). Decompositions are executed below Functionalization, so AOTAutograd produces a non-functional graph.
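
For illustration only (not part of this PR; `noisy_scale` is a made-up stand-in for the decomposition's internals), tracing a function that fills a buffer with `copy_` keeps the mutation in the graph unless functionalization runs on top of it:

```
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from torch.func import functionalize

# Stand-in for a decomposition that fills a noise buffer in place.
def noisy_scale(x):
    noise = torch.empty_like(x)
    noise.copy_(torch.rand_like(x))  # in-place mutation, like the copy_ in the rrelu decomp
    return x * noise

x = torch.randn(3)
print(make_fx(noisy_scale)(x).code)                 # trace still contains the copy_
print(make_fx(functionalize(noisy_scale))(x).code)  # mutation rewritten into functional ops
```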

That decomposition is also registered as a python_dispatch kernel, but only for AutogradCUDA.
Autograd dispatch happens above Functionalization, so registering it for the generic Autograd key (covering all backends) ensures Functionalization runs after the decomposition.
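
As a rough end-to-end check (a sketch, not from this PR; the script name is arbitrary and `TORCH_LOGS=aot_graphs` assumes a build with artifact logging), dump the AOTAutograd graphs while compiling `rrelu` and confirm no `copy_` mutations remain:

```
# Run as: TORCH_LOGS="aot_graphs" python rrelu_check.py
import torch

def fn(x):
    return torch.rrelu(x, training=True)

x = torch.randn(4, 4, requires_grad=True)
torch.compile(fn, backend="aot_eager", fullgraph=True)(x)
```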

Testing:
```
python test/functorch/test_aotdispatch.py -k test_rrelu
```

Pull Request resolved: pytorch#136008
Approved by: https://github.com/bdhirsh
IvanKobzarev authored and pytorchmergebot committed Sep 25, 2024
1 parent c3fdf58 commit 370c1c4
Showing 6 changed files with 29 additions and 10 deletions.
13 changes: 13 additions & 0 deletions test/functorch/test_aotdispatch.py
@@ -5989,6 +5989,18 @@ def forward(self, x):
         with torch.no_grad():
             torch.compile(m, fullgraph=True)(inp)
 
+    def test_rrelu(self):
+        def fn(x):
+            return torch.rrelu(x, training=True)
+
+        def fn_(x):
+            torch.rrelu_(x, training=True)
+            return x
+
+        x = torch.randn(4, 4)
+        torch.compile(fn, backend="inductor", fullgraph=True)(x)
+        torch.compile(fn_, backend="inductor", fullgraph=True)(x)
+
 
 # entries in here don't work and need to be fixed.
 # Each one of these is a bug (or needs to be investigated)
@@ -6141,6 +6153,7 @@ def _test_aot_autograd_helper(self, device, dtype, op, dynamic=False):
                 self.assertEqual,
                 check_gradients=True,
                 try_check_data_specialization=try_check_data_specialization,
+                skip_correctness_check=op.skip_correctness_check_compile_vs_eager,
             )
         except DynamicOutputShapeException:
             self.skipTest("Dynamic output shape operation in trace")
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_opinfo.py
@@ -227,7 +227,6 @@ def format_op(op):
"nn.functional.avg_pool2d": {i64},
"nn.functional.avg_pool3d": {i64},
"nn.functional.local_response_norm": {i64},
"nn.functional.rrelu": {f32, f64},
"nonzero_static": {b8, f16, f32, f64, i32, i64},
("normal", "in_place"): {f16, f32, f64},
("normal", "number_mean"): {f16, f32, f64},
4 changes: 2 additions & 2 deletions torch/_decomp/decompositions.py
@@ -306,7 +306,7 @@ def _prelu_kernel_backward(
 
 
 @register_decomposition(aten.rrelu_with_noise)
-@aten.rrelu_with_noise.default.py_impl(DispatchKey.AutogradCUDA)
+@aten.rrelu_with_noise.default.py_impl(DispatchKey.Autograd)
 @out_wrapper()
 @pw_cast_for_opmath
 def rrelu_with_noise(
@@ -330,7 +330,7 @@ def rrelu_with_noise(
 
 
 @register_decomposition(aten.rrelu_with_noise_)
-@aten.rrelu_with_noise_.default.py_impl(DispatchKey.AutogradCUDA)
+@aten.rrelu_with_noise_.default.py_impl(DispatchKey.Autograd)
 @pw_cast_for_opmath
 def rrelu_with_noise_(
     self: Tensor,
4 changes: 3 additions & 1 deletion torch/testing/_internal/common_methods_invocations.py
@@ -16149,7 +16149,9 @@ def sample_inputs_alias_copy(op_info, device, dtype, requires_grad, **kwargs):
DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_every_other'),
DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_non_contig_expand'),
DecorateInfo(unittest.skip("Different noise"), 'TestUnaryUfuncs', 'test_contig_vs_transposed'),
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'))),
DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu')),
skip_correctness_check_compile_vs_eager=True,
),
UnaryUfuncInfo(
'nn.functional.selu',
ref=lambda x, inplace=False:
2 changes: 2 additions & 0 deletions torch/testing/_internal/opinfo/core.py
@@ -896,6 +896,8 @@ class OpInfo:
 
     is_factory_function: bool = False
 
+    skip_correctness_check_compile_vs_eager: bool = False
+
     def __post_init__(self):
         self._original_opinfo_args = asdict(self).copy()
 
15 changes: 9 additions & 6 deletions torch/testing/_internal/optests/aot_autograd.py
@@ -37,7 +37,8 @@ def aot_autograd_check(
         assert_raises_regex_fn=assert_raises_regex,
         assert_equals_fn=torch.testing._comparison.assert_close,
         check_gradients=True,
-        try_check_data_specialization=False):
+        try_check_data_specialization=False,
+        skip_correctness_check=False):
     """Compares func(*args, **kwargs) in eager-mode to under AOTAutograd.
 
     Compares outputs and (if check_gradients=True) gradients produced by
@@ -73,11 +74,12 @@ def func_no_tensors(args):
         check_gradients = any_tensor_requires_grad and any_output_requires_grad
     if not check_gradients:
         compiled_out = wrapper_set_seed(compiled_f, args)
-        assert_equals_fn(compiled_out, out, msg=outputs_msg)
+        if not skip_correctness_check:
+            assert_equals_fn(compiled_out, out, msg=outputs_msg)
         return
     _test_aot_autograd_forwards_backwards_helper(
         func_no_tensors, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
-        try_check_data_specialization)
+        try_check_data_specialization, skip_correctness_check)
 
 outputs_msg = (
     "Outputs of the operator are different in eager-mode PyTorch vs "
@@ -89,7 +91,7 @@ def func_no_tensors(args):
 
 def _test_aot_autograd_forwards_backwards_helper(
         f, compiled_f, args, assert_raises_regex_fn, assert_equals_fn,
-        try_check_data_specialization):
+        try_check_data_specialization, skip_correctness_check=False):
     # Verify grads are equal between compiled and non-compiled versions of f.
 
     def call_forwards_backwards(f, args):
@@ -134,8 +136,9 @@ def check(args, ignore_failure=False):
         )
 
     compiled_out, compiled_grad = call_forwards_backwards(compiled_f, args)
-    assert_equals_fn(compiled_out, orig_out, msg=outputs_msg)
-    assert_equals_fn(compiled_grad, orig_grad, msg=msg)
+    if not skip_correctness_check:
+        assert_equals_fn(compiled_out, orig_out, msg=outputs_msg)
+        assert_equals_fn(compiled_grad, orig_grad, msg=msg)
 
     check(args, ignore_failure=False)
 
