diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index c9d3e50f5a..b6687d7edb 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -270,6 +270,41 @@ Tensor jagged_to_padded_dense(
   return padded_values;
 }
 
+template <typename scalar_t, typename F>
+Tensor jagged_dense_elementwise_dense_output_(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y,
+    F f,
+    const scalar_t padding_value = static_cast<scalar_t>(0)) {
+  Tensor output = at::empty_like(y);
+  jagged_dense_elementwise_dense_output_(
+      x_values, x_offsets, y, output, f, padding_value);
+  return output;
+}
+
+// output = x + y where x is jagged, y and output are dense
+Tensor jagged_dense_elementwise_add(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  at::cuda::OptionalCUDAGuard device_guard;
+  device_guard.set_index(x_values.get_device());
+
+  Tensor output;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x_values.scalar_type(), "jagged_scalars", [&] {
+        output = jagged_dense_elementwise_dense_output_<scalar_t>(
+            x_values,
+            x_offsets,
+            y,
+            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+              return x + y;
+            });
+      });
+  return output;
+}
+
 } // namespace
 
 Tensor
@@ -532,4 +567,6 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
 TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
   DISPATCH_TO_CUDA(
       "jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
+  DISPATCH_TO_CUDA(
+      "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
 }
diff --git a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
index 049d8c9b9f..6723a7c45e 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -247,6 +247,34 @@ Tensor jagged_to_padded_dense(
   return padded_values;
 }
 
+template <typename scalar_t, typename F>
+Tensor jagged_dense_elementwise_dense_output_(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y,
+    F f,
+    const scalar_t& padding_value = static_cast<scalar_t>(0)) {
+  Tensor output = at::empty_like(y);
+  jagged_dense_elementwise_dense_output_(
+      x_values, x_offsets, y, output, f, padding_value);
+  return output;
+}
+
+Tensor jagged_dense_elementwise_add(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  Tensor output;
+  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+      x_values.scalar_type(), "jagged_scalars", [&] {
+        output = jagged_dense_elementwise_dense_output_<scalar_t>(
+            x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
+              return x + y;
+            });
+      });
+  return output;
+}
+
 } // namespace
 
 Tensor
@@ -286,6 +314,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
       "stacked_jagged_2d_to_dense(Tensor values, Tensor lengths, int[] offset_per_key, int[] max_lengths_per_key) -> Tensor[]");
   m.def(
       "jagged_to_padded_dense(Tensor values, Tensor[] offsets, int[] max_lengths, int padding_value = 0) -> Tensor");
+  m.def(
+      "jagged_dense_elementwise_add(Tensor x_values, Tensor[] x_offsets, Tensor y) -> Tensor");
 }
 
 TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
@@ -293,4 +323,6 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
       "jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense_forward_cpu);
   DISPATCH_TO_CPU("jagged_1d_to_dense", fbgemm_gpu::jagged_1d_to_dense_cpu);
   DISPATCH_TO_CPU("jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
+  DISPATCH_TO_CPU(
+      "jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
 }
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 31bcbedadd..e61ac41484 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -1658,6 +1658,39 @@ def test_jagged_to_padded_dense(
 
         torch.testing.assert_close(output, output_ref)
 
+    # pyre-ignore [56]
+    @given(
+        num_jagged_dim=st.integers(min_value=1, max_value=5),
+        outer_dense_size=st.integers(min_value=1, max_value=5),
+        inner_dense_size=st.integers(min_value=1, max_value=5),
+        use_cpu=st.booleans() if gpu_available else st.just(True),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
+    def test_jagged_elementwise_add(
+        self,
+        num_jagged_dim: int,
+        outer_dense_size: int,
+        inner_dense_size: int,
+        use_cpu: bool,
+    ) -> None:
+        device = torch.device("cpu" if use_cpu else "cuda")
+
+        x_values, x_offsets, max_lengths = self._generate_jagged_tensor(
+            num_jagged_dim, outer_dense_size, inner_dense_size, device
+        )
+        y = torch.rand(
+            outer_dense_size * np.prod(max_lengths) * inner_dense_size,
+            dtype=torch.float,
+            device=device,
+        ).reshape((outer_dense_size,) + tuple(max_lengths) + (inner_dense_size,))
+
+        x_padded = self._to_padded_dense(x_values, x_offsets, max_lengths)
+        output_ref = x_padded + y
+
+        output = torch.ops.fbgemm.jagged_dense_elementwise_add(x_values, x_offsets, y)
+
+        torch.testing.assert_close(output, output_ref)
+
 
 if __name__ == "__main__":
     unittest.main()