Skip to content

Commit

Permalink
Check index size during decomp of index_add (pytorch#108826)
Browse files Browse the repository at this point in the history
This partially fixes the `test_index_add_correctness` test (pytorch#108181)
when run under inductor: it causes an exception to be raised [here][1]
as expected.

The test as a whole still cannot be made to pass under inductor because
the [last assert][2] still fails, likely due to pytorch#108798.

[1]: https://github.com/pytorch/pytorch/blob/dec2b267d4af159bd0a8669a679700f385a1dc98/test/test_torch.py#L6049
[2]: https://github.com/pytorch/pytorch/blob/dec2b267d4af159bd0a8669a679700f385a1dc98/test/test_torch.py#L6051
Pull Request resolved: pytorch#108826
Approved by: https://github.com/eellison
  • Loading branch information
int3 authored and pytorchmergebot committed Sep 13, 2023
1 parent d2d36aa commit db48bc8
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions torch/_decomp/decompositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -2094,6 +2094,12 @@ def _index_add(
index.ndim <= 1,
lambda: f"Index should have dimension 1 or 0 (got {index.ndim})",
)
index_size = index.size(0) if index.ndim == 1 else 1
tensor_size = tensor.size(dim) if tensor.ndim > 0 else 1
torch._check(
tensor_size == index_size,
lambda: f"Number of indices ({index_size}) should be equal to tensor.size(dim) ({tensor_size}), for {dim=}",
)
if alpha != 1:
python_type = utils.dtype_to_type(x.dtype)
torch._check(
Expand Down

0 comments on commit db48bc8

Please sign in to comment.