From 6bdc26b9ca39c43a3e43f0cb29e8ac716a0b2c44 Mon Sep 17 00:00:00 2001
From: Jongsoo Park
Date: Mon, 21 Mar 2022 09:52:52 -0700
Subject: [PATCH] jagged tensor elementwise add backward (#996)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/996

Use the (jagged, dense) -> jagged elementwise kernel added in D34840551 (https://github.com/pytorch/FBGEMM/commit/43c0f12abda5e1e4c7b424968bddfabf23de9ed5) to implement the backward pass of elementwise add, and wrap forward and backward in an autograd function.

Reviewed By: jianyuh

Differential Revision: D34844686

fbshipit-source-id: c28b1dd218d10387f6d29dfc1ae79df582c47fc2
---
 fbgemm_gpu/src/jagged_tensor_ops.cu      | 93 ++++++++++++++++++------
 fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp | 75 +++++++++++++++----
 fbgemm_gpu/test/sparse_ops_test.py       | 16 +++-
 3 files changed, 144 insertions(+), 40 deletions(-)

diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 4aa24116da..ec8db79ad4 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -9,6 +9,7 @@
 #include
 #include
 #include
+#include
 #include
 
 // clang-format off
@@ -263,28 +264,6 @@ Tensor jagged_dense_elementwise_dense_output_(
   return output;
 }
 
-// output = x + y where x is jagged, y and output are dense
-Tensor jagged_dense_elementwise_add(
-    const Tensor& x_values,
-    const std::vector<Tensor>& x_offsets,
-    const Tensor& y) {
-  at::cuda::OptionalCUDAGuard device_guard;
-  device_guard.set_index(x_values.get_device());
-
-  Tensor output;
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      x_values.scalar_type(), "jagged_scalars", [&] {
-        output = jagged_dense_elementwise_dense_output_<scalar_t>(
-            x_values,
-            x_offsets,
-            y,
-            [] __device__(scalar_t x, scalar_t y) -> scalar_t {
-              return x + y;
-            });
-      });
-  return output;
-}
-
 template
 __global__ void jagged_dense_elementwise_jagged_output_kernel_(
     const at::PackedTensorAccessor32
@@ -403,6 +382,76 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
   return {output, x_offsets};
 }
 
+class JaggedDenseAddGPUOp
+    : public torch::autograd::Function<JaggedDenseAddGPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& x_values,
+      const std::vector<Tensor>& x_offsets,
+      const Tensor& y) {
+    ctx->save_for_backward(x_offsets);
+    ctx->saved_data["x_values_shape"] = x_values.sizes();
+    ctx->saved_data["y_shape"] = y.sizes();
+
+    at::cuda::OptionalCUDAGuard device_guard;
+    device_guard.set_index(x_values.get_device());
+
+    Tensor output;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values.scalar_type(), "jagged_dense_add_forward", [&] {
+          output = jagged_dense_elementwise_dense_output_<scalar_t>(
+              x_values,
+              x_offsets,
+              y,
+              [] __device__(scalar_t x, scalar_t y) -> scalar_t {
+                return x + y;
+              });
+        });
+
+    return {output};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
+    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    at::cuda::OptionalCUDAGuard device_guard;
+    device_guard.set_index(grad_outputs[0].get_device());
+
+    Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values_grad.scalar_type(), "jagged_dense_add_backward", [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values_grad, // dummy not used in the lambda function
+              offsets,
+              grad_outputs[0],
+              x_values_grad,
+              [] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
+                return y;
+              });
+        });
+
+    return {
+        x_values_grad,
+        torch::autograd::Variable(), // x_offsets
+        grad_outputs[0]};
+  }
+};
+
+// output = x + y where x is jagged, y and output are dense
+Tensor jagged_dense_elementwise_add(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  return JaggedDenseAddGPUOp::apply(x_values, x_offsets, y)[0];
+}
+
 } // namespace
 
 Tensor
diff --git a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
index b4329bd209..592df3575a 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
+++ b/fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -6,6 +6,7 @@
  */
 #include
+#include
 #include
 
 #include "fbgemm_gpu/sparse_ops_utils.h"
 
@@ -244,21 +245,6 @@ Tensor jagged_dense_elementwise_dense_output_(
   return output;
 }
 
-Tensor jagged_dense_elementwise_add(
-    const Tensor& x_values,
-    const std::vector<Tensor>& x_offsets,
-    const Tensor& y) {
-  Tensor output;
-  AT_DISPATCH_FLOATING_TYPES_AND_HALF(
-      x_values.scalar_type(), "jagged_scalars", [&] {
-        output = jagged_dense_elementwise_dense_output_<scalar_t>(
-            x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
-              return x + y;
-            });
-      });
-  return output;
-}
-
 template <
     int NUM_JAGGED_DIM,
     bool NO_INNER_DENSE,
@@ -387,6 +373,65 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
   return {output, x_offsets};
 }
 
+class JaggedDenseAddCPUOp
+    : public torch::autograd::Function<JaggedDenseAddCPUOp> {
+ public:
+  static torch::autograd::variable_list forward(
+      torch::autograd::AutogradContext* ctx,
+      const Tensor& x_values,
+      const std::vector<Tensor>& x_offsets,
+      const Tensor& y) {
+    ctx->save_for_backward(x_offsets);
+    ctx->saved_data["x_values_shape"] = x_values.sizes();
+    ctx->saved_data["y_shape"] = y.sizes();
+
+    Tensor output;
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values.scalar_type(), "jagged_scalars", [&] {
+          output = jagged_dense_elementwise_dense_output_<scalar_t>(
+              x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
+                return x + y;
+              });
+        });
+
+    return {output};
+  }
+
+  static torch::autograd::variable_list backward(
+      torch::autograd::AutogradContext* ctx,
+      torch::autograd::variable_list grad_outputs) {
+    auto offsets = ctx->get_saved_variables();
+    auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
+    auto y_shape = ctx->saved_data["y_shape"].toIntVector();
+    TORCH_CHECK(grad_outputs.size() == 1);
+
+    Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());
+
+    AT_DISPATCH_FLOATING_TYPES_AND_HALF(
+        x_values_grad.scalar_type(), "jagged_scalars", [&] {
+          jagged_dense_elementwise_jagged_output_<scalar_t>(
+              x_values_grad, // dummy not used in the lambda function
+              offsets,
+              grad_outputs[0],
+              x_values_grad,
+              [](scalar_t /*unused*/, scalar_t y) -> scalar_t { return y; });
+        });
+
+    return {
+        x_values_grad,
+        torch::autograd::Variable(), // x_offsets
+        grad_outputs[0]};
+  }
+};
+
+// output = x + y where x is jagged, y and output are dense
+Tensor jagged_dense_elementwise_add(
+    const Tensor& x_values,
+    const std::vector<Tensor>& x_offsets,
+    const Tensor& y) {
+  return JaggedDenseAddCPUOp::apply(x_values, x_offsets, y)[0];
+}
+
 } // namespace
 
 Tensor
diff --git a/fbgemm_gpu/test/sparse_ops_test.py b/fbgemm_gpu/test/sparse_ops_test.py
index 3460032a44..1563d79dd8 100644
--- a/fbgemm_gpu/test/sparse_ops_test.py
+++ b/fbgemm_gpu/test/sparse_ops_test.py
@@ -1660,9 +1660,9 @@ def test_jagged_to_padded_dense(
 
     # pyre-ignore [56]
     @given(
-        num_jagged_dim=st.integers(min_value=1, max_value=5),
-        outer_dense_size=st.integers(min_value=1, max_value=5),
-        inner_dense_size=st.integers(min_value=1, max_value=5),
+        num_jagged_dim=st.integers(min_value=1, max_value=4),
+        outer_dense_size=st.integers(min_value=1, max_value=4),
+        inner_dense_size=st.integers(min_value=1, max_value=4),
         operation=st.sampled_from(["add", "mul"]),
         use_cpu=st.booleans() if gpu_available else st.just(True),
     )
@@ -1703,6 +1703,16 @@ def test_jagged_elementwise_binary(
 
         torch.testing.assert_close(output, output_ref)
 
+        if operation == "add":
+            torch.autograd.gradcheck(
+                torch.ops.fbgemm.jagged_dense_elementwise_add,
+                (
+                    x_values.double().requires_grad_(True),
+                    x_offsets,
+                    y.double().requires_grad_(True),
+                ),
+            )
+
 
 if __name__ == "__main__":
     unittest.main()
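
Note (editor's addition, not part of the patch): below is a minimal Python sketch of the semantics the new autograd functions implement, assuming a single jagged dimension (x_values of shape (total_L, D), one offsets tensor, dense y of shape (B, max_L, D)). The helper name jagged_dense_add_reference is made up for illustration; in FBGEMM only torch.ops.fbgemm.jagged_dense_elementwise_add is exposed.

import torch


def jagged_dense_add_reference(x_values, x_offsets, y):
    """Dense reference for jagged + dense -> dense elementwise add.

    Forward mirrors jagged_dense_elementwise_dense_output_: the jagged values
    are scattered into a zero tensor of y's shape, then added to y.
    """
    x_dense = torch.zeros_like(y)
    offsets = x_offsets[0]
    for row in range(y.size(0)):
        begin, end = int(offsets[row]), int(offsets[row + 1])
        n = min(end - begin, y.size(1))  # rows longer than max_L are truncated
        x_dense[row, :n] = x_values[begin:begin + n]
    return x_dense + y


# Backward, as implemented by JaggedDenseAddGPUOp / JaggedDenseAddCPUOp:
#   - grad wrt x_values: the dense output gradient gathered back into jagged
#     layout (the (jagged, dense) -> jagged kernel with a lambda returning y),
#   - grad wrt x_offsets: none (offsets are not differentiable),
#   - grad wrt y: the dense output gradient passed through unchanged.

if __name__ == "__main__":
    D = 4
    offsets = torch.tensor([0, 2, 5])  # two rows of lengths 2 and 3
    x_values = torch.randn(5, D, dtype=torch.double, requires_grad=True)
    y = torch.randn(2, 3, D, dtype=torch.double, requires_grad=True)
    out = jagged_dense_add_reference(x_values, [offsets], y)
    out.sum().backward()
    # Every jagged element lands in the output exactly once, so x_values.grad
    # is all ones here; y.grad is all ones as well.
    print(out.shape, x_values.grad.shape, y.grad.shape)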