symeig supports complex backward (pytorch#55085)
Summary:
Fixes pytorch#53651
I did not put much effort into improving the docs, as I will go over all of them in future PRs
cc anjali411

Pull Request resolved: pytorch#55085

Reviewed By: nikithamalgifb

Differential Revision: D27493604

Pulled By: anjali411

fbshipit-source-id: 413363013e188bc869c404b2d54ce1f87eef4425
lezcano authored and facebook-github-bot committed Apr 12, 2021
1 parent e05ca75 commit 211d31a
Showing 4 changed files with 26 additions and 25 deletions.
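
Before the file-by-file diffs, a minimal sketch of what the change enables (illustrative code, not part of the commit): backward now works through torch.symeig for complex Hermitian inputs, which previously raised an error for complex dtypes.

import torch

a = torch.randn(3, 3, dtype=torch.cdouble)
a = 0.5 * (a + a.conj().transpose(-2, -1))  # make the input Hermitian
a.requires_grad_(True)

w, v = torch.symeig(a, eigenvectors=True)   # w: real eigenvalues, v: complex eigenvectors
loss = w.sum() + v.abs().sum()              # abs() sidesteps the eigenvector phase ambiguity
loss.backward()                             # raised an error for complex inputs before this patch
print(a.grad.dtype)                         # torch.complex128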
22 changes: 1 addition & 21 deletions test/test_autograd.py
@@ -32,7 +32,7 @@
 from torch.testing._internal.common_cuda import TEST_CUDA, _get_torch_cuda_version
 from torch.testing._internal.common_utils import (TestCase, run_tests, skipIfNoLapack,
                                                    suppress_warnings, slowTest,
-                                                   load_tests, random_symmetric_matrix,
+                                                   load_tests,
                                                    IS_WINDOWS, IS_MACOS, CudaMemoryLeakCheck,
                                                    TemporaryFileName, TEST_WITH_ROCM,
                                                    gradcheck, gradgradcheck)
@@ -2839,26 +2839,6 @@ def test_var_mean_differentiable(self):
         torch.autograd.backward(r2, grad)
         self.assertTrue(torch.allclose(input1.grad, input2.grad, rtol=0.01, atol=0.0))
 
-    @skipIfNoLapack
-    def test_symeig(self):
-        def func(root, upper):
-            x = 0.5 * (root + root.transpose(-2, -1))
-            return torch.symeig(x, eigenvectors=True, upper=upper)
-
-        def run_test(upper, dims):
-            root = torch.rand(*dims, requires_grad=True)
-
-            gradcheck(func, [root, upper])
-            gradgradcheck(func, [root, upper])
-
-            root = random_symmetric_matrix(dims[-1], *dims[:-2]).requires_grad_()
-            w, v = root.symeig(eigenvectors=True)
-            (w.sum() + v.sum()).backward()
-            self.assertEqual(root.grad, root.grad.transpose(-1, -2))  # Check the gradient is symmetric
-
-        for upper, dims in product([True, False], [(3, 3), (5, 3, 3), (4, 3, 2, 2)]):
-            run_test(upper, dims)
-
     @slowTest
     @skipIfNoLapack
     def test_lobpcg(self):
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
@@ -87,7 +87,7 @@
     'linalg_solve', 'sqrt', 'stack', 'gather', 'index_select', 'index_add_', 'linalg_inv',
     'l1_loss_backward', 'baddbmm', 'addbmm', 'addmm', 'addmv', 'addr', 'linalg_householder_product',
     'constant_pad_nd', 'reflection_pad1d', 'reflection_pad2d',
-    'reflection_pad1d_backward', 'reflection_pad2d_backward',
+    'reflection_pad1d_backward', 'reflection_pad2d_backward', 'symeig',
     'replication_pad1d', 'replication_pad2d', 'replication_pad3d', 'take', 'put_',
     'replication_pad1d_backward', 'replication_pad2d_backward', 'replication_pad3d_backward',
     'diag', 'masked_scatter', 'masked_select', 'index_fill', 'trace', 'polar', 'cumsum', 'rsub',
6 changes: 3 additions & 3 deletions torch/_torch_docs.py
@@ -8409,7 +8409,7 @@ def merge_dicts(*dicts):
 symeig(input, eigenvectors=False, upper=True, *, out=None) -> (Tensor, Tensor)
 
 This function returns eigenvalues and eigenvectors
-of a real symmetric matrix :attr:`input` or a batch of real symmetric matrices,
+of a real symmetric or complex Hermitian matrix :attr:`input` or a batch thereof,
 represented by a namedtuple (eigenvalues, eigenvectors).
 
 This function calculates all eigenvalues (and vectors) of :attr:`input`
@@ -8421,7 +8421,7 @@ def merge_dicts(*dicts):
 If it is ``False``, only eigenvalues are computed. If it is ``True``,
 both eigenvalues and eigenvectors are computed.
 
-Since the input matrix :attr:`input` is supposed to be symmetric,
+Since the input matrix :attr:`input` is supposed to be symmetric or Hermitian,
 only the upper triangular portion is used by default.
 If :attr:`upper` is ``False``, then lower triangular portion is used.
@@ -8438,7 +8438,7 @@ def merge_dicts(*dicts):
 Args:
     input (Tensor): the input tensor of size :math:`(*, n, n)` where `*` is zero or more
-                batch dimensions consisting of symmetric matrices.
+                batch dimensions consisting of symmetric or Hermitian matrices.
     eigenvectors(bool, optional): controls whether eigenvectors have to be computed
     upper(boolean, optional): controls whether to consider upper-triangular or lower-triangular region
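
To make the documented behaviour concrete, a short illustration (not part of the commit):

import torch

a = torch.randn(2, 4, 4, dtype=torch.cdouble)
a = 0.5 * (a + a.conj().transpose(-2, -1))  # a batch of Hermitian matrices

e = torch.symeig(a)                         # eigenvalues only; e.eigenvectors is an empty tensor
w, v = torch.symeig(a, eigenvectors=True)   # namedtuple (eigenvalues, eigenvectors)
# w has shape (2, 4) and is real; v has shape (2, 4, 4) and is complex.
# With upper=True (the default) only the upper triangle of each matrix is read;
# upper=False reads the lower triangle instead.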
21 changes: 21 additions & 0 deletions torch/testing/_internal/common_methods_invocations.py
@@ -1558,6 +1558,16 @@ def sample_inputs_linalg_cholesky(op_info, device, dtype, requires_grad=False, **kwargs):
         out.append(SampleInput(a))
     return out
 
+def sample_inputs_symeig(op_info, device, dtype, requires_grad=False):
+    out = sample_inputs_linalg_invertible(op_info, device, dtype, requires_grad)
+
+    for o in out:
+        o.kwargs = {"upper": bool(np.random.choice([True, False])),
+                    "eigenvectors": True}
+        # A gauge-invariant function
+        o.output_process_fn_grad = lambda output: (output[0], abs(output[1]))
+    return out
+
 
 def sample_inputs_linalg_eigh(op_info, device, dtype, requires_grad=False, **kwargs):
     """
@@ -2782,6 +2792,17 @@ def gradcheck_wrapper_triangular_input(op, input, *args, upper=False, **kwargs):
            skips=(
                # cholesky_inverse does not correctly warn when resizing out= inputs
                SkipInfo('TestCommon', 'test_out'),)),
+    OpInfo('symeig',
+           dtypes=floating_and_complex_types(),
+           check_batched_gradgrad=False,
+           sample_inputs_func=sample_inputs_symeig,
+           gradcheck_wrapper=gradcheck_wrapper_hermitian_input,
+           decorators=[skipCUDAIfNoMagma, skipCUDAIfRocm, skipCPUIfNoLapack],
+           skips=(
+               # cuda gradchecks are slow
+               # see discussion https://github.com/pytorch/pytorch/pull/47761#issuecomment-747316775
+               SkipInfo('TestGradients', 'test_fn_gradgrad', device_type='cuda'),)
+           ),
     UnaryUfuncInfo('clamp',
                    aliases=('clip', ),
                    decorators=(precisionOverride({torch.bfloat16: 7e-2, torch.float16: 1e-2}),),
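
Putting the pieces together, the new OpInfo exercises roughly the following: gradcheck_wrapper_hermitian_input symmetrizes the input so gradcheck can perturb all entries freely, and the abs() from sample_inputs_symeig fixes the gauge. A standalone sketch under those assumptions, not the commit's code:

import torch

def gauge_fixed_symeig(x):
    # Symmetrize: conj().transpose() is a no-op for real dtypes,
    # so the op always sees a Hermitian argument.
    x = 0.5 * (x + x.conj().transpose(-2, -1))
    w, v = torch.symeig(x, eigenvectors=True)
    return w, v.abs()                        # gauge-invariant outputs

root = torch.randn(3, 3, dtype=torch.cdouble, requires_grad=True)
assert torch.autograd.gradcheck(gauge_fixed_symeig, (root,))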
