Skip to content

Commit

Permalink
Expand GRADIENT_IMPLEMENTED_FOR_COMPLEX to allow named tensors (pytor…
Browse files Browse the repository at this point in the history
…ch#47289)

Summary:
Complex-valued named tensors do not support backpropagation currently. This is due to `tools/autograd/gen_variable_type.py` not containing `alias` in `GRADIENT_IMPLEMENTED_FOR_COMPLEX` which is required to constructed named tensors.

This fixes pytorch#47157. Also removed a duplicate `cholesky` in the list and added a test in `test_autograd.py`.

Apologies, this is a duplicate of pytorch#47181 as I accidently removed my pytorch fork.

cc: zou3519 anjali411

Pull Request resolved: pytorch#47289

Reviewed By: agolynski

Differential Revision: D24706571

Pulled By: zou3519

fbshipit-source-id: 2cc48ce38eb180183c5b4ce2f8f4eef8bcac0316
  • Loading branch information
jonasteuwen authored and facebook-github-bot committed Nov 4, 2020
1 parent 5d82311 commit a11bc04
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
8 changes: 8 additions & 0 deletions test/test_autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -4697,6 +4697,14 @@ def fn(a, b):
with self.assertRaisesRegex(RuntimeError, bad_mark_dirty_err):
fn(a, b)

def test_named_tensor_for_complex_views(self):
names = ["batch", "height", "width", "complex"]
z = torch.ones((5, 12, 14, 2), requires_grad=True)
z_named = z.refine_names(*names)
z_complex = torch.view_as_complex(z_named.rename(None)).refine_names(*names[:-1])
z_complex.sum().backward()
self.assertEqual(z.grad, torch.view_as_real(torch.ones_like(z_complex).rename(None)))

def test_custom_function_return_view_in_nograd(self):
class Alias(Function):
@staticmethod
Expand Down
2 changes: 1 addition & 1 deletion tools/autograd/gen_variable_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@
'neg', 'complex', 'select', '_s_where', 'as_strided', 'slice', 'constant_pad_nd',
'unbind', 'split', 'split_with_sizes', 'unsafe_split', 'split_with_sizes_backward',
'dot', 'vdot', 'cholesky', 'triangular_solve', 'mm', '_unsafe_view', 'mv', 'ger',
'bmm', 'diagonal', 'cholesky', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'bmm', 'diagonal', 'alias', 'atan', 'log', 'log10', 'log1p', 'log2', 'reciprocal',
'tan', 'pow', 'rsqrt', 'tanh', 'tanh_backward', 'asinh', 'acosh', 'take', 'fill_',
'exp', 'nonzero'
}
Expand Down

0 comments on commit a11bc04

Please sign in to comment.