[pt2] test if core decomps are differentiable (pytorch#107241)
Pull Request resolved: pytorch#107241
Approved by: https://github.com/ezyang
nkaretnikov authored and pytorchmergebot committed Aug 18, 2023
1 parent 5b7b9e7 commit 77f080e
Showing 1 changed file with 71 additions and 14 deletions.
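
The gist of the change: ops that have a registered core ATen decomposition and support autograd now get a gradcheck-based backward test. As a quick orientation (this sketch is not part of the diff, and `torch.lerp` is only a stand-in example op), the two ingredients the new test combines look roughly like this:

```python
import torch
from torch._decomp import core_aten_decompositions, decomposition_table

# The core ATen decompositions are a curated subset of the full decomposition
# table; the new test only collects ops from this subset that support autograd.
core = core_aten_decompositions()
print(f"core decompositions: {len(core)}, full table: {len(decomposition_table)}")

# gradcheck compares analytical gradients against finite differences, which is
# why the new test restricts itself to torch.float64 inputs.
x = torch.randn(4, dtype=torch.float64, requires_grad=True)
y = torch.randn(4, dtype=torch.float64, requires_grad=True)
torch.autograd.gradcheck(lambda a, b: torch.lerp(a, b, 0.5), (x, y))
```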
test/test_decomp.py: 85 changes (71 additions, 14 deletions)
@@ -3,7 +3,7 @@
 from collections import defaultdict
 from torch import Tensor
 import torch.autograd
-from torch._decomp import decomposition_table
+from torch._decomp import core_aten_decompositions, decomposition_table
 from torch.utils._python_dispatch import TorchDispatchMode

 from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
@@ -25,7 +25,7 @@
     instantiate_device_type_tests,
     onlyCUDA,
 )
-from torch.testing._internal.common_methods_invocations import op_db
+from torch.testing._internal.common_methods_invocations import op_db, skip, skipOps, xfail
 from torch._dispatch.python import enable_python_dispatcher
 from torch._ops import DispatchKey

@@ -44,12 +44,19 @@ def overload_to_aten_name(overload):

 # All operators that can have decomp tests
 decomposition_names = {overload_to_aten_name(k) for k in decomposition_table}
+core_decomposition_names = {overload_to_aten_name(k) for k in core_aten_decompositions()}
 _decomp_test_ops = [
     op
     for op in op_db
     if op.aten_name in decomposition_names
     or op.aten_backward_name in decomposition_names
 ]
+_decomp_test_ops_core_autograd = [
+    op
+    for op in op_db
+    if op.aten_name in core_decomposition_names
+    and op.supports_autograd
+]


 def diff_arg(arg, requires_grad=True):
@@ -403,6 +410,40 @@ def test_unsupported(t):
     return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs))


+core_backward_failures = {
+    skip('_softmax_backward_data'),  # slow: fails with --timeout=360 secs
+    xfail('addcdiv'),
+    skip('addcmul'),  # slow: fails with --timeout=360 secs
+    skip('deg2rad'),  # slow: fails with --timeout=360 secs
+    skip('diag_embed'),  # slow: fails with --timeout=360 secs
+    skip('frac'),  # slow: fails with --timeout=360 secs
+    skip('grid_sampler_2d'),  # slow: fails with --timeout=360 secs
+    xfail('lerp'),
+    skip('logaddexp'),  # slow: fails with --timeout=360 secs
+    skip('native_dropout_backward'),  # slow: fails with --timeout=360 secs
+    xfail('nn.functional.binary_cross_entropy_with_logits'),
+    xfail('nn.functional.hardshrink'),
+    xfail('nn.functional.softshrink'),
+    skip('nn.functional.unfold'),  # slow: fails with --timeout=360 secs
+    xfail('norm'),
+    xfail('norm', 'fro'),
+    xfail('norm', 'inf'),
+    xfail('norm', 'nuc'),
+    skip('rad2deg'),  # slow: fails with --timeout=360 secs
+    skip('renorm'),  # slow: fails with --timeout=360 secs
+    skip('rot90'),  # slow: fails with --timeout=360 secs
+    skip('rsub'),  # slow: fails with --timeout=360 secs
+    skip('sgn'),  # slow: fails with --timeout=360 secs
+    skip('special.xlog1py'),  # slow: fails with --timeout=360 secs
+    xfail('stack'),
+    skip('tril'),  # slow: fails with --timeout=360 secs
+    skip('triu'),  # slow: fails with --timeout=360 secs
+    skip('unfold_copy'),  # slow: fails with --timeout=360 secs
+    skip('xlogy'),  # slow: fails with --timeout=360 secs
+    xfail('zero_'),
+}
+
+
 class TestDecomp(TestCase):
     longMessage = True

@@ -417,6 +458,23 @@ class TestDecomp(TestCase):
     def test_quick(self, device, dtype, op):
         self.do_cross_ref(device, dtype, op, run_all=False)

+    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
+    @skipOps('TestDecomp', 'test_quick_core_backward', core_backward_failures)
+    @onlyNativeDeviceTypes
+    @skipIfCrossRef
+    @suppress_warnings
+    @ops(_decomp_test_ops_core_autograd, allowed_dtypes=(torch.float64,))
+    def test_quick_core_backward(self, device, dtype, op):
+        for sample_input in op.sample_inputs(device, dtype, requires_grad=True):
+            aten_name = op.decomp_aten_name or op.aten_name
+            args = [sample_input.input] + list(sample_input.args)
+            kwargs = sample_input.kwargs
+            func = partial(op.get_op(), **kwargs)
+            with self.DecompCrossRefMode(self, self.precision, self.rel_tol, dtype, run_all=False)\
+                 as mode, enable_python_dispatcher():
+                torch.autograd.gradcheck(func, args)
+            self.check_decomposed(aten_name, mode)
+
     @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
     @onlyNativeDeviceTypes
     @skipIfCrossRef
@@ -618,6 +676,14 @@ def __torch_dispatch__(self, func, types, args=(), kwargs=None):

             return real_out_unflat

+    def check_decomposed(self, aten_name, mode):
+        self.assertTrue(
+            any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
+            msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
+                 f"{', '.join(map(str, list(mode.called)))}. If your op is "
+                 f"CompositeImplicitAutograd you should skip this test "
+                 f"by updating CROSS_REF_EXCLUDE_SET.")
+        )

     @skipIfTorchDynamo("Test does not work with TorchDynamo")
     def do_cross_ref(self, device, dtype, op, *, run_all):
@@ -642,15 +708,6 @@ def do_cross_ref(self, device, dtype, op, *, run_all):
         )
         samples = op.sample_inputs(device, dtype, requires_grad=requires_grad)

-        def check_decomposed(aten_name, mode):
-            self.assertTrue(
-                any(overload_to_aten_name(c) == aten_name for c in mode.decomposed),
-                msg=(f"aten.{aten_name} was not decomposed, saw calls for: "
-                     f"{', '.join(map(str, list(mode.called)))}. If your op is "
-                     f"CompositeImplicitAutograd you should skip this test "
-                     "by updating CROSS_REF_EXCLUDE_SET.")
-            )
-
         aten_name = op.decomp_aten_name or op.aten_name

         func = op.get_op()
@@ -669,7 +726,7 @@ def check_decomposed(aten_name, mode):
                      as mode, enable_python_dispatcher():
                     decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                 if aten_name in decomposition_names:
-                    check_decomposed(aten_name, mode)
+                    self.check_decomposed(aten_name, mode)

                 if not skip_decomp_vjp and (op.aten_backward_name in decomposition_names or run_all):
                     cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)
@@ -678,7 +735,7 @@ def check_decomposed(aten_name, mode):
                          as mode, enable_python_dispatcher():
                         decomp_vjp_fn(cotangents)
                     if not run_all:
-                        check_decomposed(op.aten_backward_name, mode)
+                        self.check_decomposed(op.aten_backward_name, mode)

             elif aten_name in decomposition_names or run_all:
                 args = [sample_input.input] + list(sample_input.args)
@@ -687,7 +744,7 @@ def check_decomposed(aten_name, mode):
                      as mode, enable_python_dispatcher():
                     func(*args, **kwargs)
                 if not run_all:
-                    check_decomposed(aten_name, mode)
+                    self.check_decomposed(aten_name, mode)
             else:
                 assert op.supports_autograd
                 self.skipTest(
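The `check_decomposed` helper (moved from a nested function in `do_cross_ref` to a method so the new test can reuse it) relies on the cross-ref mode recording which ATen overloads it intercepted (`mode.decomposed`, `mode.called`). Below is a minimal, self-contained sketch of that recording idea; `RecordAtenCalls` is a hypothetical stand-in, not the test's actual `DecompCrossRefMode`:

```python
import torch
from torch.utils._python_dispatch import TorchDispatchMode

class RecordAtenCalls(TorchDispatchMode):
    """Log every ATen overload seen at __torch_dispatch__ level."""
    def __init__(self):
        super().__init__()
        self.called = set()

    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        self.called.add(func)  # func is an OpOverload, e.g. aten.lerp.Scalar
        return func(*args, **(kwargs or {}))

with RecordAtenCalls() as mode:
    torch.lerp(torch.randn(3), torch.randn(3), 0.5)

print(sorted(str(op) for op in mode.called))
```

To run only the new test locally, name-based filtering should work with either runner, e.g. `python test/test_decomp.py -k test_quick_core_backward` or `pytest test/test_decomp.py -k test_quick_core_backward`.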
