Skip to content

Commit

Permalink
jagged tensor elementwise add (pytorch#992)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#992

Using the generic element-wise kernel from the previous diff, implement jagged tensor element-wise add.

Reviewed By: jasonjk-park

Differential Revision: D34812899

fbshipit-source-id: d0671e8934d502ebd627b29d0379c637f38405cd
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Mar 19, 2022
1 parent ca07944 commit 8910ce6
Show file tree
Hide file tree
Showing 3 changed files with 102 additions and 0 deletions.
37 changes: 37 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -270,6 +270,41 @@ Tensor jagged_to_padded_dense(
return padded_values;
}

// Convenience overload: allocates a dense result tensor shaped like `y`
// and forwards to the overload that writes into a caller-provided output.
//
// `padding_value` fills positions of the dense result that have no
// corresponding jagged element (defaults to 0).
template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_dense_output_(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y,
    F f,
    const scalar_t padding_value = static_cast<scalar_t>(0)) {
  Tensor result = at::empty_like(y);
  jagged_dense_elementwise_dense_output_(
      x_values, x_offsets, y, result, f, padding_value);
  return result;
}

// output = x + y where x is jagged, y and output are dense
//
// x_values holds the jagged operand's flattened values; x_offsets are its
// offset tensors (presumably one per jagged dimension — confirm against the
// generic kernel in the previous diff). The result has y's shape; positions
// with no jagged element come from the helper's padding_value default of 0.
Tensor jagged_dense_elementwise_add(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y) {
  // Run the kernel on the device that owns the jagged values.
  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(x_values.get_device());

  Tensor output;
  // Dispatch over float/double/half; the extended __device__ lambda is the
  // per-element binary op applied by the generic jagged/dense kernel.
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x_values.scalar_type(), "jagged_scalars", [&] {
        output = jagged_dense_elementwise_dense_output_<scalar_t>(
            x_values,
            x_offsets,
            y,
            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
              return x + y;
            });
      });
  return output;
}

} // namespace

Tensor
Expand Down Expand Up @@ -532,4 +567,6 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
// Register the CUDA implementations of the jagged-tensor operators with the
// fbgemm dispatcher (schemas are declared via m.def in the CPU file).
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
  DISPATCH_TO_CUDA(
      "jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
  DISPATCH_TO_CUDA(
      "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
}
32 changes: 32 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -247,6 +247,34 @@ Tensor jagged_to_padded_dense(
return padded_values;
}

// Convenience overload: allocates a dense output tensor shaped like `y` and
// delegates to the overload that writes into a caller-provided output.
//
// `padding_value` fills result positions with no corresponding jagged
// element (defaults to 0). It is taken by value: scalar_t is a cheap
// arithmetic type, so pass-by-value avoids an unnecessary indirection and
// matches the CUDA variant of this helper.
template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_dense_output_(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y,
    F f,
    const scalar_t padding_value = static_cast<scalar_t>(0)) {
  Tensor output = at::empty_like(y);
  jagged_dense_elementwise_dense_output_(
      x_values, x_offsets, y, output, f, padding_value);
  return output;
}

// output = x + y where x is jagged, y and output are dense (CPU path).
// The result has y's shape; x_values/x_offsets describe the jagged operand.
Tensor jagged_dense_elementwise_add(
    const Tensor& x_values,
    const std::vector<Tensor>& x_offsets,
    const Tensor& y) {
  Tensor result;
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
      x_values.scalar_type(), "jagged_scalars", [&] {
        // Per-element binary op handed to the generic jagged/dense kernel.
        auto add_fn = [](scalar_t a, scalar_t b) -> scalar_t { return a + b; };
        result = jagged_dense_elementwise_dense_output_<scalar_t>(
            x_values, x_offsets, y, add_fn);
      });
  return result;
}

} // namespace

Tensor
Expand Down Expand Up @@ -286,11 +314,15 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
"stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> Tensor[]");
m.def(
"jagged_to_padded_dense(Tensor values, Tensor[] offsets, int[] max_lengths, int padding_value = 0) -> Tensor");
m.def(
"jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor");
}

// Register the CPU implementations of the jagged-tensor operators with the
// fbgemm dispatcher (schemas declared in the TORCH_LIBRARY_FRAGMENT above).
TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
  DISPATCH_TO_CPU(
      "jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense_forward_cpu);
  DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_cpu);
  DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
  DISPATCH_TO_CPU(
      "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
}
33 changes: 33 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1658,6 +1658,39 @@ def test_jagged_to_padded_dense(

torch.testing.assert_close(output, output_ref)

# pyre-ignore [56]
@given(
    num_jagged_dim=st.integers(min_value=1, max_value=5),
    outer_dense_size=st.integers(min_value=1, max_value=5),
    inner_dense_size=st.integers(min_value=1, max_value=5),
    use_cpu=st.booleans() if gpu_available else st.just(True),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_jagged_elementwise_add(
    self,
    num_jagged_dim: int,
    outer_dense_size: int,
    inner_dense_size: int,
    use_cpu: bool,
) -> None:
    """jagged_dense_elementwise_add must match padding the jagged operand
    to dense and adding the tensors normally."""
    device = torch.device("cpu" if use_cpu else "cuda")

    x_values, x_offsets, max_lengths = self._generate_jagged_tensor(
        num_jagged_dim, outer_dense_size, inner_dense_size, device
    )

    # Dense operand shaped to the fully padded layout of the jagged tensor.
    dense_shape = (outer_dense_size,) + tuple(max_lengths) + (inner_dense_size,)
    y = torch.rand(dense_shape, dtype=torch.float, device=device)

    # Reference: pad the jagged side to dense, then add elementwise.
    expected = self._to_padded_dense(x_values, x_offsets, max_lengths) + y

    actual = torch.ops.fbgemm.jagged_dense_elementwise_add(x_values, x_offsets, y)

    torch.testing.assert_close(actual, expected)


if __name__ == "__main__":
unittest.main()

0 comments on commit 8910ce6

Please sign in to comment.