Skip to content

Commit

Permalink
add backward of jagged_to_padded_dense (pytorch#1008)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#1008

and reuse in jagged_2d_to_dense and jagged_1d_to_dense

Reviewed By: jasonjk-park

Differential Revision: D35104497

fbshipit-source-id: 138498fc3153c249b9eff674fd8f51f0d7673492
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Mar 25, 2022
1 parent 9cf1a9f commit f99e161
Show file tree
Hide file tree
Showing 5 changed files with 227 additions and 198 deletions.
10 changes: 0 additions & 10 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,16 +222,6 @@ at::Tensor batched_unary_embeddings_backward_cuda(
const at::Tensor& offsets,
const at::Tensor& indices);

at::Tensor jagged_2d_to_dense_forward_cuda(
at::Tensor values,
at::Tensor offsets,
int32_t max_L);

at::Tensor jagged_2d_to_dense_backward_cuda(
at::Tensor grad_padded_values,
at::Tensor offsets,
int32_t total_L);

std::tuple<std::vector<at::Tensor>, std::vector<at::Tensor>>
stacked_jagged_2d_to_dense_forward_cuda(
at::Tensor values,
Expand Down
210 changes: 116 additions & 94 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -226,58 +226,6 @@ void jagged_dense_elementwise_dense_output_(
#undef INVOKE_KERNEL_WITH_DIM
}

// Almost identical copy of jagged_to_padded_dense in jagged_tensor_ops_cpu.cpp
Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const int64_t padding_value) {
const size_t num_jagged_dim = offsets.size();
TORCH_CHECK(
max_lengths.size() == num_jagged_dim,
"max_lengths.size(), ",
max_lengths.size(),
" != num_jagged_dim, ",
num_jagged_dim);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());

const Tensor values_canonicalized = values.view(
{values.size(0),
std::accumulate(
values.sizes().begin() + 1,
values.sizes().end(),
1,
std::multiplies<size_t>())});
at::DimVector padded_values_shape({offsets[0].size(0) - 1});
padded_values_shape.insert(
padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
if (values.dim() > 1) {
padded_values_shape.push_back(values.size(-1));
}
Tensor padded_values = at::empty(padded_values_shape, values.options());
Tensor padded_values_view =
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
values.scalar_type(),
"jagged_to_padded_dense",
[&] {
jagged_dense_elementwise_dense_output_<scalar_t>(
values_canonicalized,
offsets,
padded_values_view, // dummy not used in the lambda function
padded_values_view,
[] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
return x;
},
static_cast<scalar_t>(padding_value));
});

return padded_values;
}

template <typename scalar_t, typename F>
Tensor jagged_dense_elementwise_dense_output_(
const Tensor& x_values,
Expand Down Expand Up @@ -396,6 +344,117 @@ Tensor jagged_dense_elementwise_jagged_output_(
return output;
}

class JaggedToPaddedDenseGPUOp
: public torch::autograd::Function<JaggedToPaddedDenseGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const double padding_value) {
ctx->save_for_backward(offsets);
ctx->saved_data["total_L"] = values.size(0);

const size_t num_jagged_dim = offsets.size();
TORCH_CHECK(
max_lengths.size() == num_jagged_dim,
"max_lengths.size(), ",
max_lengths.size(),
" != num_jagged_dim, ",
num_jagged_dim);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(values.get_device());

const Tensor values_canonicalized = values.view(
{values.size(0),
std::accumulate(
values.sizes().begin() + 1,
values.sizes().end(),
1,
std::multiplies<size_t>())});
at::DimVector padded_values_shape({offsets[0].size(0) - 1});
padded_values_shape.insert(
padded_values_shape.end(), max_lengths.begin(), max_lengths.end());
if (values.dim() > 1) {
padded_values_shape.push_back(values.size(-1));
}
Tensor padded_values = at::empty(padded_values_shape, values.options());
Tensor padded_values_view =
values.dim() == 1 ? padded_values.unsqueeze(-1) : padded_values;

AT_DISPATCH_ALL_TYPES_AND(
at::ScalarType::Half,
values.scalar_type(),
"jagged_to_padded_dense",
[&] {
jagged_dense_elementwise_dense_output_<scalar_t>(
values_canonicalized,
offsets,
padded_values_view, // dummy not used in the lambda function
padded_values_view,
[] __device__(scalar_t x, scalar_t /*unused*/) -> scalar_t {
return x;
},
static_cast<scalar_t>(padding_value));
});

return {padded_values};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
auto offsets = ctx->get_saved_variables();
int32_t total_L = ctx->saved_data["total_L"].toInt();
TORCH_CHECK(grad_outputs.size() == 1);

TORCH_CHECK(total_L >= 0);
auto grad_padded_values = grad_outputs[0];
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());

int32_t D = grad_padded_values.size(-1);
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values, // dummy not used in the lambda function
{offsets},
grad_padded_values,
grad_values,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});

return {
grad_values,
torch::autograd::Variable(), // offsets
torch::autograd::Variable(), // max_lengths
torch::autograd::Variable(), // padding_value
};
}
};

Tensor jagged_to_padded_dense(
const Tensor& values,
const std::vector<Tensor>& offsets,
const std::vector<int64_t>& max_lengths,
const double padding_value) {
return JaggedToPaddedDenseGPUOp::apply(
values, offsets, max_lengths, padding_value)[0];
}

Tensor
jagged_2d_to_dense(Tensor values, Tensor offsets, int64_t max_sequence_length) {
return jagged_to_padded_dense(
values, {offsets}, {max_sequence_length}, /*padding_value=*/0);
}

class JaggedDenseAddGPUOp
: public torch::autograd::Function<JaggedDenseAddGPUOp> {
public:
Expand Down Expand Up @@ -989,45 +1048,6 @@ Tensor batched_dense_vec_jagged_2d_mul(
} // namespace
Tensor
jagged_2d_to_dense_forward_cuda(Tensor values, Tensor offsets, int32_t max_L) {
TORCH_CHECK(values.dim() == 2);
TORCH_CHECK(offsets.dim() == 1);
TORCH_CHECK(max_L > 0);
return jagged_to_padded_dense(values, {offsets}, {max_L}, 0);
}
Tensor jagged_2d_to_dense_backward_cuda(
Tensor grad_padded_values,
Tensor offsets,
int32_t total_L) {
TORCH_CHECK(grad_padded_values.dim() == 3);
TORCH_CHECK(offsets.dim() == 1);
TORCH_CHECK(total_L >= 0);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());
int32_t D = grad_padded_values.size(2);
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values, // dummy not used in the lambda function
{offsets},
grad_padded_values,
grad_values,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});
return grad_values;
}
Tensor jagged_1d_to_dense_gpu(
Tensor values,
Tensor offsets,
Expand Down Expand Up @@ -1088,10 +1108,11 @@ stacked_jagged_2d_to_dense_forward_cuda(
});
offsets_tensor_per_key.push_back(offsets);
padded_values_per_key.push_back(jagged_2d_to_dense_forward_cuda(
padded_values_per_key.push_back(jagged_to_padded_dense(
values.slice(0, offset_per_key[t], offset_per_key[t + 1]),
offsets,
max_L));
{offsets},
{max_L},
/*padding_value=*/0));
}
return std::make_tuple(padded_values_per_key, offsets_tensor_per_key);
Expand Down Expand Up @@ -1193,6 +1214,7 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA(
"jagged_to_padded_dense", fbgemm_gpu::jagged_to_padded_dense);
DISPATCH_TO_CUDA("jagged_2d_to_dense", fbgemm_gpu::jagged_2d_to_dense);
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
DISPATCH_TO_CUDA(
Expand Down
Loading

0 comments on commit f99e161

Please sign in to comment.