handle corner case of empty tensors in jagged tensor ops (pytorch#998)
Summary:
Pull Request resolved: pytorch#998

Handle the case where B (batch size) or D (embedding dimension) is 0, i.e., the jagged/dense tensors involved are empty.

Reviewed By: xing-liu, jasonjk-park

Differential Revision: D34990893

fbshipit-source-id: b3e3f57961d6d0cf0c91929b06b4f35e9d2b47d0
jspark1105 authored and facebook-github-bot committed Mar 21, 2022
1 parent 7dff50b commit 443e391
Showing 3 changed files with 39 additions and 9 deletions.
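For context, a minimal sketch of the corner case being handled (illustrative only, not part of the commit; shapes and names are hypothetical). A jagged tensor is a values tensor plus offsets, and when B or D is 0 both the jagged values and the corresponding dense tensor are empty:

import torch

B, D, max_L = 0, 4, 3                            # B == 0: an empty batch
offsets = torch.zeros(B + 1, dtype=torch.long)   # B + 1 offsets -> tensor([0])
values = torch.empty(0, D)                       # no jagged rows at all
dense = torch.empty(B, max_L, D)                 # dense counterpart, numel() == 0

# Each op patched below now begins with the equivalent of this guard:
if dense.numel() == 0:
    pass  # nothing to do -- return before any kernel launch or size arithmetic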
14 changes: 14 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -168,6 +168,10 @@ void jagged_dense_elementwise_dense_output_(
   const int num_jagged_dim = y.dim() - 2;
   TORCH_CHECK(x_offsets.size() == static_cast<size_t>(num_jagged_dim));
 
+  if (y.numel() == 0) {
+    return;
+  }
+
   dim3 threads, blocks;
   Tensor jagged_dims_tensor;
   std::tie(threads, blocks, jagged_dims_tensor) =
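Why the guard matters on the GPU side (a sketch, not from the commit): the launch grid is derived from the dense tensor's extents, presumably by ceiling division against the thread count, and an empty tensor yields a zero-block grid, which CUDA rejects as an invalid launch configuration. The same guard is added to the two kernels below. Mirroring the arithmetic in Python (div_round_up here is an assumed stand-in, not fbgemm's actual helper):

def div_round_up(a: int, b: int) -> int:
    # Ceiling division: element count -> number of blocks.
    return (a + b - 1) // b

threads_per_block = 256
numel = 0                                   # y.numel() for an empty tensor
blocks = div_round_up(numel, threads_per_block)
print(blocks)  # 0 -- launching a kernel with 0 blocks fails, hence the early return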
@@ -314,6 +318,10 @@ void jagged_dense_elementwise_jagged_output_(
   const int num_jagged_dim = y.dim() - 2;
   TORCH_CHECK(x_offsets.size() == static_cast<size_t>(num_jagged_dim));
 
+  if (y.numel() == 0) {
+    return;
+  }
+
   dim3 threads, blocks;
   Tensor jagged_dims_tensor;
   std::tie(threads, blocks, jagged_dims_tensor) =
@@ -493,6 +501,10 @@ void jagged_jagged_elementwise_dense_output_(
   const int num_jagged_dim = output.dim() - 2;
   TORCH_CHECK(x_offsets.size() == static_cast<size_t>(num_jagged_dim));
 
+  if (output.numel() == 0) {
+    return;
+  }
+
   dim3 threads, blocks;
   Tensor jagged_dims_tensor;
   std::tie(threads, blocks, jagged_dims_tensor) =
@@ -848,6 +860,8 @@ class BatchedDenseVecJagged2DMulGPUOp
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
       });
+    } else {
+      v_grad.zero_();
     }
     return {
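The last hunk above is different in kind: in the backward pass of BatchedDenseVecJagged2DMulGPUOp, the gradient buffer v_grad can be non-empty even when the kernel that writes it is skipped (e.g., D == 0 makes the incoming gradient empty while v is not). Assuming v_grad is allocated uninitialized with something like at::empty_like (the truncated diff does not show the allocation or the if condition), the added else branch must zero-fill it. A sketch of the same idea:

import torch

B, H, max_L, D = 2, 3, 8, 0     # D == 0: the incoming gradient is empty
v = torch.rand(B * H, max_L)    # dense vectors from the forward pass
v_grad = torch.empty_like(v)    # uninitialized, like at::empty_like in C++

if D > 0:                       # stand-in for the guard in the surrounding code
    pass                        # normally: launch the kernel that fills v_grad
else:
    v_grad.zero_()              # the added else branch: no kernel ran, so the
                                # buffer would otherwise hold garbage
assert bool(v_grad.eq(0).all())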
15 changes: 15 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -100,6 +100,11 @@ void jagged_dense_elementwise_dense_output_kernel_(
   TORCH_CHECK(!NO_INNER_DENSE || y.size(-1) == 1);
   const int inner_dense_size = NO_INNER_DENSE ? 1 : y.size(-1);
   TORCH_CHECK(inner_dense_size == x_values.size(-1));
+
+  if (y.numel() == 0) {
+    return;
+  }
+
   const int jagged_folded_size =
       y.numel() / (outer_dense_size * inner_dense_size);
   const int jagged_innermost_size = y.size(-2);
@@ -269,6 +274,11 @@ void jagged_dense_elementwise_jagged_output_kernel_(
   TORCH_CHECK(!NO_INNER_DENSE || y.size(-1) == 1);
   const int inner_dense_size = NO_INNER_DENSE ? 1 : y.size(-1);
   TORCH_CHECK(inner_dense_size == x_values.size(-1));
+
+  if (y.numel() == 0) {
+    return;
+  }
+
   const int jagged_folded_size =
       y.numel() / (outer_dense_size * inner_dense_size);
   const int jagged_innermost_size = y.size(-2);
@@ -442,6 +452,11 @@ void jagged_jagged_elementwise_dense_output_kernel_(
   TORCH_CHECK(!NO_INNER_DENSE || output.size(-1) == 1);
   const int inner_dense_size = NO_INNER_DENSE ? 1 : output.size(-1);
   TORCH_CHECK(inner_dense_size == x_values.size(-1));
+
+  if (output.numel() == 0) {
+    return;
+  }
+
   const int jagged_folded_size =
       output.numel() / (outer_dense_size * inner_dense_size);
   const int jagged_innermost_size = output.size(-2);
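Note where the CPU guards sit: immediately before jagged_folded_size = y.numel() / (outer_dense_size * inner_dense_size). Since inner_dense_size is checked to equal x_values.size(-1), D == 0 makes the divisor zero, so without the early return this integer division would be undefined behavior in C++, not just wasted work. A sketch of the arithmetic (function name hypothetical):

def folded_size(numel: int, outer: int, inner: int) -> int:
    # Mirrors jagged_folded_size = y.numel() / (outer_dense_size * inner_dense_size).
    if numel == 0:               # the guard added by this commit
        return 0
    return numel // (outer * inner)

print(folded_size(0, 3, 0))      # 0 -- without the guard: ZeroDivisionError here,
                                 # undefined behavior in the C++ above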
19 changes: 10 additions & 9 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -14,7 +14,7 @@
 import hypothesis.strategies as st
 import numpy as np
 import torch
-from hypothesis import Verbosity, given, settings
+from hypothesis import Verbosity, assume, given, settings
 
 try:
     # pyre-ignore[21]
@@ -1610,9 +1610,9 @@ def _generate_jagged_tensor(
 
     # pyre-ignore [56]
     @given(
-        num_jagged_dim=st.integers(min_value=1, max_value=5),
-        outer_dense_size=st.integers(min_value=1, max_value=5),
-        inner_dense_size=st.integers(min_value=1, max_value=5),
+        num_jagged_dim=st.integers(1, 5),
+        outer_dense_size=st.integers(0, 5),
+        inner_dense_size=st.integers(0, 5),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
@@ -1660,9 +1660,9 @@ def test_jagged_to_padded_dense(
 
     # pyre-ignore [56]
     @given(
-        num_jagged_dim=st.integers(min_value=1, max_value=4),
-        outer_dense_size=st.integers(min_value=1, max_value=4),
-        inner_dense_size=st.integers(min_value=1, max_value=4),
+        num_jagged_dim=st.integers(1, 4),
+        outer_dense_size=st.integers(0, 4),
+        inner_dense_size=st.integers(0, 4),
         operation=st.sampled_from(["add", "mul"]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
@@ -1720,10 +1720,10 @@ def test_jagged_elementwise_binary(
     )
     # pyre-ignore [56]
     @given(
-        B=st.integers(1, 32),
+        B=st.integers(0, 32),
         H=st.integers(1, 3),
         max_L=st.integers(1, 32),
-        D=st.integers(1, 32),
+        D=st.integers(0, 32),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
     def test_batched_dense_vec_jagged_2d_mul(
@@ -1734,6 +1734,7 @@ def test_batched_dense_vec_jagged_2d_mul(
         D: int,
         use_cpu: bool,
     ) -> None:
+        assume(H == 1 or B != 0)
         device = torch.device("cpu" if use_cpu else "cuda")
         torch.backends.cuda.matmul.allow_tf32 = False
 
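The test changes exercise the new corner case: the strategy bounds drop from a minimum of 1 to 0 for outer_dense_size, inner_dense_size, B, and D, so hypothesis now generates empty shapes, while the new assume() call skips H > 1 combined with an empty batch (B == 0). A minimal sketch of the pattern (test body elided):

from hypothesis import assume, given, settings
import hypothesis.strategies as st

@settings(max_examples=20, deadline=None)
@given(B=st.integers(0, 32), H=st.integers(1, 3))
def test_sketch(B: int, H: int) -> None:
    # assume() discards generated inputs rather than failing on them.
    assume(H == 1 or B != 0)
    assert B >= 0  # placeholder for the real op-vs-reference check

test_sketch()  # hypothesis runs the decorated function directly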
