jagged-jagged -> dense elementwise op (pytorch#995)
Summary:
Pull Request resolved: pytorch#995

Add a kernel variation that computes element-wise operations between two jagged tensors and writes the result to a dense tensor. This is intended for internal use, for example in the backward pass of element-wise multiplication (it would not make sense for a user-facing element-wise operation between jagged tensors with the same sparsity pattern to return a dense tensor); see the CPU reference sketch below.

Reviewed By: xing-liu

Differential Revision: D34845894

fbshipit-source-id: 7fb3843bd3cc1f3b4255c3db83d81873f756d147
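
For orientation, here is a minimal CPU reference sketch (not part of the commit) of the new jagged-jagged -> dense element-wise op, assuming a single jagged dimension. The helper name jagged_jagged_to_dense_ref and the plain std::vector layout are hypothetical; the diff below implements the same idea as a CUDA kernel over packed tensor accessors. The main() illustrates why the backward of jagged * dense multiplication needs it: dL/dy is grad_output * x_values at the jagged positions and zero elsewhere.

// Illustration only -- a CPU reference sketch, not part of this commit.
#include <algorithm>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// f(x, y) over two jagged tensors that share one offsets array; dense output
// of shape [B, max_len, D], pre-filled with padding_value where a batch has
// no row.
std::vector<std::vector<std::vector<float>>> jagged_jagged_to_dense_ref(
    const std::vector<std::vector<float>>& x_values,
    const std::vector<std::vector<float>>& y_values,
    const std::vector<int64_t>& offsets, // size B + 1, shared by x and y
    int64_t max_len,
    const std::function<float(float, float)>& f,
    float padding_value = 0.0f) {
  const int64_t B = static_cast<int64_t>(offsets.size()) - 1;
  const int64_t D =
      x_values.empty() ? 0 : static_cast<int64_t>(x_values[0].size());
  std::vector<std::vector<std::vector<float>>> out(
      B,
      std::vector<std::vector<float>>(
          max_len, std::vector<float>(D, padding_value)));
  for (int64_t b = 0; b < B; ++b) {
    const int64_t len = std::min(offsets[b + 1] - offsets[b], max_len);
    for (int64_t j = 0; j < len; ++j) {
      const int64_t row = offsets[b] + j;
      for (int64_t d = 0; d < D; ++d) {
        out[b][j][d] = f(x_values[row][d], y_values[row][d]);
      }
    }
  }
  return out;
}

int main() {
  // Backward of jagged * dense multiplication: dL/dy at each dense position is
  // grad_output * x_values where the jagged tensor owns a row, 0 elsewhere --
  // exactly a jagged-jagged -> dense element-wise multiply.
  const std::vector<std::vector<float>> grad_output = {{1, 1}, {2, 2}, {3, 3}};
  const std::vector<std::vector<float>> x_values = {{10, 20}, {30, 40}, {50, 60}};
  const std::vector<int64_t> offsets = {0, 2, 3}; // batch 0: rows 0-1, batch 1: row 2
  const auto y_grad = jagged_jagged_to_dense_ref(
      grad_output, x_values, offsets, /*max_len=*/2,
      [](float a, float b) { return a * b; });
  for (const auto& batch : y_grad) {
    for (const auto& row : batch) {
      for (float v : row) {
        std::cout << v << ' ';
      }
      std::cout << '\n';
    }
  }
  return 0;
}
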
jspark1105 authored and facebook-github-bot committed Mar 21, 2022
1 parent 6bdc26b commit 2f83ca0
Showing 3 changed files with 379 additions and 46 deletions.
208 changes: 186 additions & 22 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -360,28 +360,6 @@ Tensor jagged_dense_elementwise_jagged_output_(
return output;
}

std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
});

return {output, x_offsets};
}

class JaggedDenseAddGPUOp
: public torch::autograd::Function<JaggedDenseAddGPUOp> {
public:
@@ -452,6 +430,192 @@ Tensor jagged_dense_elementwise_add(
return JaggedDenseAddGPUOp::apply(x_values, x_offsets, y)[0];
}

/**
* output = f(x, y) where x and y are jagged (and share x_offsets), and output
* is dense.
*
* @param padding_value Padding value used for the dense output (not applied to the inputs)
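*
* Example (illustrative, single jagged dimension): with x_offsets[0] = [0, 2, 3]
* and max_len == 2, the dense output has shape [2, 2, D]; rows 0 and 1 of
* x_values/y_values fill batch 0, row 2 fills batch 1, and the remaining slot
* output[1][1][:] is filled with padding_value.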
*/
template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
__global__ void jagged_jagged_elementwise_dense_output_kernel_(
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
x_values,
const std::array<index_t*, NUM_JAGGED_DIM> x_offsets,
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
y_values,
at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> output,
const int64_t* jagged_dims,
F f,
const scalar_t padding_value) {
const int outer_dense_size = output.size(0);
const int jagged_folded_size = output.size(1);
const int inner_dense_size = output.size(2);

const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y;
const int outer_stride = gridDim.x * blockDim.y;
for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
outer += outer_stride) {
const int oidx = outer / jagged_folded_size;
const int jidx = outer % jagged_folded_size;

int offset = oidx;
const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
offset, jidx, jagged_dims, x_offsets);

if (is_zero) {
for (int iidx = threadIdx.x; iidx < inner_dense_size;
iidx += blockDim.x) {
output[oidx][jidx][iidx] = padding_value;
}
} else {
for (int iidx = threadIdx.x; iidx < inner_dense_size;
iidx += blockDim.x) {
output[oidx][jidx][iidx] =
f(x_values[offset][iidx], y_values[offset][iidx]);
}
}
}
}

template <typename scalar_t, typename F>
void jagged_jagged_elementwise_dense_output_(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y_values,
const Tensor& output,
F f,
const scalar_t padding_value = static_cast<scalar_t>(0)) {
TENSOR_ON_CUDA_GPU(x_values);
for (auto& x_offset : x_offsets) {
TENSOR_ON_CUDA_GPU(x_offset);
}

const int num_jagged_dim = output.dim() - 2;
TORCH_CHECK(x_offsets.size() == static_cast<size_t>(num_jagged_dim));

dim3 threads, blocks;
Tensor jagged_dims_tensor;
std::tie(threads, blocks, jagged_dims_tensor) =
check_shape_and_partition_(x_values, x_offsets, output);

// Canonicalize output to 3D, collapsing jagged dimensions.
Tensor output_reshaped = output.view({output.size(0), -1, output.size(-1)});

#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
{ \
Tensor x_offsets_contig[num_jagged_dim]; \
std::array<index_t*, NUM_JAGGED_DIM> x_offset_ptrs; \
for (int d = 0; d < num_jagged_dim; ++d) { \
x_offsets_contig[d] = x_offsets[d].contiguous(); \
x_offset_ptrs[d] = x_offsets_contig[d].template data_ptr<index_t>(); \
} \
jagged_jagged_elementwise_dense_output_kernel_<NUM_JAGGED_DIM, index_t> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
x_offset_ptrs, \
y_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
output_reshaped \
.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
jagged_dims_tensor.data_ptr<int64_t>(), \
f, \
padding_value); \
}

JAGGED_TENSOR_DISPATCH_DIMS();
C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef INVOKE_KERNEL_WITH_DIM
}

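// Autograd wrapper for element-wise multiplication between a jagged tensor
// (x_values with x_offsets) and a dense tensor y. The forward output is jagged
// and shares x_offsets; the backward uses the jagged-jagged -> dense kernel
// above to produce the dense gradient for y.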
class JaggedDenseMulGPUOp
: public torch::autograd::Function<JaggedDenseMulGPUOp> {
public:
static torch::autograd::variable_list forward(
torch::autograd::AutogradContext* ctx,
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
std::vector<Tensor> tensors_to_save;
tensors_to_save.push_back(x_values);
tensors_to_save.insert(
tensors_to_save.end(), x_offsets.begin(), x_offsets.end());
tensors_to_save.push_back(y);
ctx->save_for_backward(tensors_to_save);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(x_values.get_device());

Tensor output;
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
output = jagged_dense_elementwise_jagged_output_<scalar_t>(
x_values,
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
});

return {output};
}

static torch::autograd::variable_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::variable_list grad_outputs) {
const Tensor x_values = ctx->get_saved_variables().front();
// Somehow, the following (commented-out) code generates a segfault during
// atomic operations, probably related to reference counting.
// std::vector<Tensor> x_offsets(
// ctx->get_saved_variables().begin() + 1,
// ctx->get_saved_variables().end() - 1);
std::vector<Tensor> x_offsets;
for (int i = 1; i < ctx->get_saved_variables().size() - 1; ++i) {
x_offsets.push_back(ctx->get_saved_variables()[i]);
}
Tensor y = ctx->get_saved_variables().back();
TORCH_CHECK(grad_outputs.size() == 1);

at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_outputs[0].get_device());

Tensor x_values_grad;
Tensor y_grad = at::empty_like(y);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
x_values.scalar_type(), "jagged_scalars", [&] {
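// dL/dx_values = grad_output * y, evaluated at the jagged coordinates
// (jagged output sharing x_offsets).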
x_values_grad = jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_outputs[0],
x_offsets,
y,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});

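// dL/dy = grad_output * x_values at the jagged coordinates, written into a
// dense tensor; positions not covered by the jagged layout receive the
// default padding_value of 0.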
jagged_jagged_elementwise_dense_output_<scalar_t>(
grad_outputs[0],
x_offsets,
x_values,
y_grad,
[] __device__(scalar_t x, scalar_t y) -> scalar_t {
return x * y;
});
});

return {
x_values_grad,
torch::autograd::Variable(), // x_offsets
y_grad};
}
};

std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y) {
return {JaggedDenseMulGPUOp::apply(x_values, x_offsets, y)[0], x_offsets};
}

} // namespace

Tensor
