jagged tensor elementwise add backward (pytorch#996)
Summary:
Pull Request resolved: pytorch#996

Use the (jagged, dense) -> jagged kernel from D34840551 (pytorch@43c0f12) to implement the backward pass of elementwise add, and wrap the op in an autograd function.
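
For reference, a minimal pure-PyTorch sketch of the op's contract for a single jagged dimension (helper names and shapes are illustrative assumptions, not the FBGEMM kernels):

import torch

# x_values: (total_rows, D) jagged values; x_offsets: (B + 1,) row offsets.
# y and the output are dense with shape (B, max_len, D).
def jagged_dense_add_reference(x_values, x_offsets, y):
    output = y.clone()
    max_len = y.shape[1]
    for b in range(y.shape[0]):
        start, end = int(x_offsets[b]), int(x_offsets[b + 1])
        n = min(end - start, max_len)  # jagged rows beyond max_len are dropped
        output[b, :n] += x_values[start : start + n]
    return output

# Backward: the gradient w.r.t. x_values gathers grad_output at the jagged
# positions -- exactly the (jagged, dense) -> jagged kernel with f(x, y) = y.
# The gradient w.r.t. y is grad_output itself, since add is linear and the
# output is dense.
def jagged_dense_add_backward_reference(grad_output, x_offsets, total_rows):
    x_values_grad = grad_output.new_zeros(total_rows, grad_output.shape[-1])
    max_len = grad_output.shape[1]
    for b in range(grad_output.shape[0]):
        start, end = int(x_offsets[b]), int(x_offsets[b + 1])
        n = min(end - start, max_len)
        x_values_grad[start : start + n] = grad_output[b, :n]
    return x_values_grad, grad_output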

Reviewed By: jianyuh

Differential Revision: D34844686

fbshipit-source-id: c28b1dd218d10387f6d29dfc1ae79df582c47fc2
jspark1105 authored and facebook-github-bot committed Mar 21, 2022
1 parent 27a9d08 commit 6bdc26b
Showing 3 changed files with 144 additions and 40 deletions.
93 changes: 71 additions & 22 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -9,6 +9,7 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>

// clang-format off
@@ -263,28 +264,6 @@ Tensor jagged_dense_elementwise_dense_output_(
return output;
}

// output = x + y where x is jagged, y and output are dense
Tensor jagged_dense_elementwise_add(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_dense_output_<scalar_t>(
x_values,
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
});
return output;
}

template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
__global__ void jagged_dense_elementwise_jagged_output_kernel_(
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
@@ -403,6 +382,76 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
return {output, x_offsets};
}

class JaggedDenseAddGPUOp
: public torch::autograd::Function<JaggedDenseAddGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
ctx->save_for_backward(x_offsets);
ctx->saved_data["x_values_shape"] = x_values.sizes();
ctx->saved_data["y_shape"] = y.sizes();

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_dense_add_forward", [&] {
output = jagged_dense_elementwise_dense_output_<scalar_t>(
x_values,
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
});

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
auto y_shape = ctx->saved_data["y_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());

Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values_grad.scalar_type(), "jagged_dense_add_backward", [&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values_grad, // dummy input; its values are ignored by the lambda below
offsets,
grad_outputs[0],
x_values_grad,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});

return {
x_values_grad,
torch::autograd::Variable(), // x_offsets
grad_outputs[0]};
}
};

// output = x + y where x is jagged, y and output are dense
Tensor jagged_dense_elementwise_add(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
return JaggedDenseAddGPUOp::apply(x_values, x_offsets, y)[0];
}

} // namespace

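Why the backward reduces to a gather: output[b, t] = y[b, t] + x[b, t] wherever the jagged tensor has a row, and y[b, t] elsewhere, so both partial derivatives are 1. The dense gradient therefore passes through to y unchanged, while the gradient for x_values is grad_output read back at the jagged positions, which is what the (jagged, dense) -> jagged kernel with a lambda returning its second argument computes. A small self-contained numeric check of that rule (shapes are illustrative):

import torch

offsets = torch.tensor([0, 2, 3])  # batch of 2 with lengths 2 and 1
x_values = torch.randn(3, 2, dtype=torch.double, requires_grad=True)
y = torch.randn(2, 2, 2, dtype=torch.double, requires_grad=True)

out = y.clone()
for b in range(2):
    s, e = int(offsets[b]), int(offsets[b + 1])
    out[b, : e - s] += x_values[s:e]

grad_out = torch.randn_like(out)
out.backward(grad_out)

# grad w.r.t. y is grad_out unchanged; grad w.r.t. x_values is a gather.
assert torch.equal(y.grad, grad_out)
for b in range(2):
    s, e = int(offsets[b]), int(offsets[b + 1])
    assert torch.equal(x_values.grad[s:e], grad_out[b, : e - s])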
75 changes: 60 additions & 15 deletions fbgemm_gpu/src/jagged_tensor_ops_cpu.cpp
@@ -6,6 +6,7 @@
*/

#include <ATen/ATen.h>
#include <torch/csrc/autograd/custom_function.h>
#include <torch/library.h>

#include "fbgemm_gpu/sparse_ops_utils.h"
@@ -244,21 +245,6 @@ Tensor jagged_dense_elementwise_dense_output_(
return output;
}

Tensor jagged_dense_elementwise_add(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_dense_output_<scalar_t>(
x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
});
return output;
}

template <
int NUM_JAGGED_DIM,
bool NO_INNER_DENSE,
@@ -387,6 +373,65 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
return {output, x_offsets};
}

class JaggedDenseAddCPUOp
: public torch::autograd::Function<JaggedDenseAddCPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
ctx->save_for_backward(x_offsets);
ctx->saved_data["x_values_shape"] = x_values.sizes();
ctx->saved_data["y_shape"] = y.sizes();

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_dense_output_<scalar_t>(
x_values, x_offsets, y, [](scalar_t x, scalar_t y) -> scalar_t {
return x + y;
});
});

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
auto x_values_shape = ctx->saved_data["x_values_shape"].toIntVector();
auto y_shape = ctx->saved_data["y_shape"].toIntVector();
TORCH_CHECK(grad_outputs.size() == 1);

Tensor x_values_grad = at::empty(x_values_shape, grad_outputs[0].options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values_grad.scalar_type(), "jagged_scalars", [&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values_grad, // dummy input; its values are ignored by the lambda below
offsets,
grad_outputs[0],
x_values_grad,
[](scalar_t /*unused*/, scalar_t y) -> scalar_t { return y; });
});

return {
x_values_grad,
torch::autograd::Variable(), // x_offsets
grad_outputs[0]};
}
};

// output = x + y where x is jagged, y and output are dense
Tensor jagged_dense_elementwise_add(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
return JaggedDenseAddCPUOp::apply(x_values, x_offsets, y)[0];
}

} // namespace

16 changes: 13 additions & 3 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -1660,9 +1660,9 @@ def test_jagged_to_padded_dense(

# pyre-ignore [56]
@given(
num_jagged_dim=st.integers(min_value=1, max_value=5),
outer_dense_size=st.integers(min_value=1, max_value=5),
inner_dense_size=st.integers(min_value=1, max_value=5),
num_jagged_dim=st.integers(min_value=1, max_value=4),
outer_dense_size=st.integers(min_value=1, max_value=4),
inner_dense_size=st.integers(min_value=1, max_value=4),
operation=st.sampled_from(["add", "mul"]),
use_cpu=st.booleans() if gpu_available else st.just(True),
)
@@ -1703,6 +1703,16 @@ def test_jagged_elementwise_binary(

torch.testing.assert_close(output, output_ref)

if operation == "add":
torch.autograd.gradcheck(
torch.ops.fbgemm.jagged_dense_elementwise_add,
(
x_values.double().requires_grad_(True),
x_offsets,
y.double().requires_grad_(True),
),
)


if __name__ == "__main__":
unittest.main()
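
End to end, the registered op can be exercised the same way the test above does; a hedged usage sketch, assuming an fbgemm_gpu build that registers the operator:

import torch
import fbgemm_gpu  # noqa: F401 -- registers the torch.ops.fbgemm operators

B, max_len, D = 2, 3, 4
x_offsets = [torch.tensor([0, 2, 5])]  # lengths 2 and 3 -> 5 jagged rows
x_values = torch.randn(5, D, requires_grad=True)
y = torch.randn(B, max_len, D, requires_grad=True)

out = torch.ops.fbgemm.jagged_dense_elementwise_add(x_values, x_offsets, y)
out.sum().backward()
# y.grad is all ones; x_values.grad gathers the dense gradient at the
# jagged positions, so here it is all ones as well.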
