Add Int types in jagged_index_select (pytorch#1218)
Summary:
Pull Request resolved: pytorch#1218

Add support for Int types in jagged_index_select and update the test
to validate Int types

Reviewed By: jianyuh

Differential Revision: D38105237

fbshipit-source-id: c9392da8378759bda384abd768a44f0aef5ad1b1
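For context, a minimal sketch of what the widened dtype support enables (hypothetical shapes and values; the call signature follows the updated test below, and the example assumes that importing fbgemm_gpu registers the operator):

import torch
import fbgemm_gpu  # noqa: F401  # assumed to register torch.ops.fbgemm.*

# A jagged tensor with three rows of lengths 2, 0, and 1; values is the dense
# (sum(lengths) x num_cols) buffer, which may now hold an integer type.
values = torch.randint(100, (3, 4), dtype=torch.int, device="cuda")
lengths = torch.tensor([2, 0, 1], device="cuda")
indices = torch.tensor([2, 0], device="cuda")  # select jagged rows 2 and 0

# The second return value is unused here, as in the test below.
output, _ = torch.ops.fbgemm.jagged_index_select(values, lengths, indices)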
Sarunya Pumma authored and facebook-github-bot committed Jul 26, 2022
1 parent 09099ea commit 949420c
Showing 2 changed files with 36 additions and 13 deletions.
14 changes: 10 additions & 4 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1623,8 +1623,11 @@ Tensor jagged_index_select_2d_cuda(
       at::empty({num_dense_output_rows, num_cols}, values.options());
   if (num_blocks > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        values.scalar_type(), "jagged_index_select_2d_kernel_wrapper_1", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,
+        values.scalar_type(),
+        "jagged_index_select_2d_kernel_wrapper_1",
+        [&] {
           AT_DISPATCH_INDEX_TYPES(
               indices.scalar_type(),
               "jagged_index_select_2d_kernel_wrapper_2",
@@ -1715,8 +1718,11 @@ Tensor jagged_index_add_2d_cuda(
   Tensor output = at::zeros({num_output_rows, num_cols}, grad.options());
   if (num_blocks > 0) {
-    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-        grad.scalar_type(), "jagged_index_add_2d_kernel_wrapper_1", [&] {
+    AT_DISPATCH_ALL_TYPES_AND(
+        at::ScalarType::Half,
+        grad.scalar_type(),
+        "jagged_index_add_2d_kernel_wrapper_1",
+        [&] {
           AT_DISPATCH_INDEX_TYPES(
               indices.scalar_type(),
               "jagged_index_add_2d_kernel_wrapper_2",
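The substantive change in both hunks above is the dispatch macro: AT_DISPATCH_FLOATING_TYPES_AND_HALF generates kernels only for float, double, and half, whereas AT_DISPATCH_ALL_TYPES_AND(at::ScalarType::Half, ...) also covers the standard integral scalar types (uint8 through int64), with half re-added via the extra argument. A hypothetical smoke check of the widened coverage from Python (illustrative shapes; assumes fbgemm_gpu is importable on a CUDA machine):

import torch
import fbgemm_gpu  # noqa: F401  # assumed to register torch.ops.fbgemm.*

lengths = torch.tensor([1, 2], device="cuda")
indices = torch.tensor([1, 1, 0], device="cuda")
for dtype in [torch.float, torch.half, torch.int, torch.long]:
    values = torch.ones(3, 8, dtype=dtype, device="cuda")  # 3 == lengths.sum()
    out, _ = torch.ops.fbgemm.jagged_index_select(values, lengths, indices)
    assert out.dtype == dtype  # dispatch now succeeds for int types too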
35 changes: 26 additions & 9 deletions fbgemm_gpu/test/jagged_tensor_ops_test.py
@@ -972,7 +972,9 @@ def jagged_index_select_2d_ref(
         num_cols=st.integers(1, 128),
         num_jagged_tensor_rows=st.integers(1, 128),
         index_dtype=st.sampled_from([torch.int, torch.long]),
-        jagged_tensor_dtype=st.sampled_from([torch.float, torch.half]),
+        jagged_tensor_dtype=st.sampled_from(
+            [torch.float, torch.half, torch.int, torch.long]
+        ),
     )
     @settings(max_examples=20, deadline=None)
     def test_jagged_index_select_2d(
@@ -984,6 +986,7 @@ def test_jagged_index_select_2d(
         index_dtype: torch.dtype,
         jagged_tensor_dtype: torch.dtype,
     ) -> None:
+        is_float = jagged_tensor_dtype in [torch.float, torch.half]
         lengths = torch.randint(
             low=0,
             high=max_seq_length,
@@ -1000,21 +1003,35 @@
                 device="cuda",
             )
         )
-        values = torch.rand(
-            int(lengths.sum().item()),
-            num_cols,
-            dtype=jagged_tensor_dtype,
-            device="cuda",
-        )
+        if is_float:
+            values = torch.rand(
+                int(lengths.sum().item()),
+                num_cols,
+                dtype=jagged_tensor_dtype,
+                device="cuda",
+            )
+        else:
+            values = torch.randint(
+                2**16,
+                (int(lengths.sum().item()), num_cols),
+                dtype=jagged_tensor_dtype,
+                device="cuda",
+            )
         values_ref = values.detach().clone()
-        values.requires_grad = True
-        values_ref.requires_grad = True
+
+        # Only float tensors can require grad
+        if is_float:
+            values.requires_grad = True
+            values_ref.requires_grad = True
 
         output, _ = torch.ops.fbgemm.jagged_index_select(values, lengths, indices)
         output_ref = self.jagged_index_select_2d_ref(values_ref, lengths, indices)
 
         assert torch.equal(output, output_ref)
 
+        if not is_float:
+            return
+
         grad = torch.rand_like(output)
         grad_ref = grad.detach().clone()
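The early return for integer dtypes reflects a general autograd restriction rather than anything specific to this operator: only floating point (and complex) tensors may require gradients, so the backward check that follows the assert only makes sense for float and half. A standalone illustration:

import torch

x = torch.zeros(2, 3, dtype=torch.int)
try:
    x.requires_grad = True
except RuntimeError as e:
    # PyTorch rejects this for non-float, non-complex dtypes, which is why
    # the test skips the gradient comparison when values are int or long.
    print(e)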
