[pt2] add meta for _cdist_backward (pytorch#106680)
Pull Request resolved: pytorch#106680
Approved by: https://github.com/Skylion007
nkaretnikov authored and pytorchmergebot committed Aug 7, 2023
1 parent 05e1a50 commit f694bcc
Showing 2 changed files with 19 additions and 1 deletion.
test/functorch/test_aotdispatch.py (1 change: 0 additions & 1 deletion)
@@ -2811,7 +2811,6 @@ def forward(self, x):

 symbolic_aot_autograd_failures = {
     xfail('block_diag', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
-    xfail('cdist', ''),  # Cannot call sizes() on tensor with symbolic sizes/strides
     xfail('combinations', ''),  # aten.masked_select.default
     xfail('diff', ''),  # aten.zeros_like.default - couldn't find symbolic meta function/decomposition
     xfail('digamma', ''),  # aten.polygamma.default - couldn't find symbolic meta function/decomposition
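With the xfail removed, cdist is expected to pass under symbolic shapes in AOT Autograd. A minimal sketch of exercising that path directly via make_fx symbolic tracing (the function, shapes, and p=1.0 below are illustrative assumptions, not part of the commit):

import torch
from torch.fx.experimental.proxy_tensor import make_fx

def cdist_grad(x1, x2):
    # p=1.0 keeps the computation on the aten._cdist_forward/_cdist_backward path.
    d = torch.cdist(x1, x2, p=1.0)
    grad_x1, = torch.autograd.grad(d.sum(), x1)
    return grad_x1

x1 = torch.randn(3, 5, requires_grad=True)
x2 = torch.randn(4, 5)

# Symbolic tracing of the backward needs a meta function for _cdist_backward;
# without one, tracing fails with errors like
# "Cannot call sizes() on tensor with symbolic sizes/strides".
gm = make_fx(cdist_grad, tracing_mode="symbolic")(x1, x2)
print(gm.graph)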
torch/_meta_registrations.py (19 changes: 19 additions & 0 deletions)
@@ -3004,6 +3004,25 @@ def meta_cdist_forward(x1, x2, p, compute_mode):
     return x1.new_empty(output_shape)
 
 
+@register_meta(aten._cdist_backward)
+@out_wrapper()
+def meta_cdist_backward(grad, x1, x2, p, cdist):
+    c1 = x1.shape[-1]
+    r1 = x1.shape[-2]
+    r2 = x2.shape[-2]
+    batch_tensor1 = x1.shape[:-2]
+    batch_tensor2 = x2.shape[:-2]
+    expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
+    tensor1_expand_size = expand_batch_portion[:]
+    tensor1_expand_size.extend([r1, c1])
+    batch_product = math.prod(expand_batch_portion)
+    if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
+        return torch.zeros_like(x1)
+    if tensor1_expand_size != list(x1.shape):
+        x1 = x1.expand(tensor1_expand_size)
+    return torch.empty_like(x1, memory_format=torch.contiguous_format)
+
+
 # NB: This meta function accepts non-meta arguments! When this behavior
 # was originally introduced this was accidental, but it is now load bearing
 # as people are using this so that they can conveniently test code involving
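The new registration can also be checked directly on the "meta" device, where only shapes and dtypes are computed. A minimal sketch (the shapes and the p value are illustrative assumptions):

import torch

x1 = torch.randn(3, 5, device="meta")
x2 = torch.randn(4, 5, device="meta")

# The forward already had a meta kernel (meta_cdist_forward above): shape (3, 4).
dist = torch.ops.aten._cdist_forward(x1, x2, 1.0, None)

# With meta_cdist_backward registered, the backward op also resolves on meta
# tensors and returns a tensor shaped like x1 without doing any real work.
grad_x1 = torch.ops.aten._cdist_backward(torch.ones_like(dist), x1, x2, 1.0, dist)
print(grad_x1.shape)  # torch.Size([3, 5])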
