jagged tensor elementwise op with jagged output (pytorch#993)
Summary:
Pull Request resolved: pytorch#993

jagged-dense -> jagged element-wise kernel (for ops that produce zero when either input is zero, e.g., multiplication)
Use the kernel to implement the jagged_to_padded_dense backward pass.

Reviewed By: jasonjk-park

Differential Revision: D34840551

fbshipit-source-id: 8f22688c0af999ba21c499c6b0b28b57c4b1fbf0
jspark1105 authored and facebook-github-bot committed Mar 19, 2022
1 parent 8910ce6 commit 43c0f12
Showing 3 changed files with 298 additions and 180 deletions.
25 changes: 25 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops_utils.h
@@ -262,3 +262,28 @@ constexpr uint32_t cuda_calc_block_count(
return std::min(
cuda_calc_xblock_count(num_items, threads_per_block), max_blocks);
}

// Used in jagged_tensor_ops.cu and jagged_tensor_ops_cpu.cpp
#define JAGGED_TENSOR_DISPATCH_DIMS() \
AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [&] { \
switch (num_jagged_dim) { \
case 1: \
INVOKE_KERNEL_WITH_DIM(1); \
break; \
case 2: \
INVOKE_KERNEL_WITH_DIM(2); \
break; \
case 3: \
INVOKE_KERNEL_WITH_DIM(3); \
break; \
case 4: \
INVOKE_KERNEL_WITH_DIM(4); \
break; \
case 5: \
INVOKE_KERNEL_WITH_DIM(5); \
break; \
default: \
TORCH_CHECK( \
false, "unsupported number of jagged dim ", num_jagged_dim); \
} \
});
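
The macro expects the call site to define INVOKE_KERNEL_WITH_DIM(N) for a kernel templated on the compile-time number of jagged dimensions and on the dispatched index_t, and to undefine it afterwards. A minimal sketch of that pattern follows; my_jagged_kernel_, blocks, and threads are placeholder names, not identifiers from this commit:

// Usage sketch only; my_jagged_kernel_, blocks, and threads are placeholders
// for whatever the call site provides. index_t comes from AT_DISPATCH_INDEX_TYPES.
#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM)                    \
  my_jagged_kernel_<NUM_JAGGED_DIM, index_t>                      \
      <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
          /* kernel arguments */);

JAGGED_TENSOR_DISPATCH_DIMS();
C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef INVOKE_KERNEL_WITH_DIM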
220 changes: 114 additions & 106 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -197,28 +197,8 @@ void jagged_dense_elementwise_dense_output_(
padding_value); \
}

AT_DISPATCH_INDEX_TYPES(x_offsets[0].scalar_type(), "jagged_indices", [&] {
switch (num_jagged_dim) {
case 1:
INVOKE_KERNEL_WITH_DIM(1);
break;
case 2:
INVOKE_KERNEL_WITH_DIM(2);
break;
case 3:
INVOKE_KERNEL_WITH_DIM(3);
break;
case 4:
INVOKE_KERNEL_WITH_DIM(4);
break;
case 5:
INVOKE_KERNEL_WITH_DIM(5);
break;
default:
TORCH_CHECK(false, "unsupported number of jagged dim ", num_jagged_dim);
}
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
JAGGED_TENSOR_DISPATCH_DIMS();
C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef INVOKE_KERNEL_WITH_DIM
}
@@ -305,6 +285,90 @@ Tensor jagged_dense_elementwise_add(
return output;
}

template <int NUM_JAGGED_DIM, typename index_t, typename scalar_t, typename F>
__global__ void jagged_dense_elementwise_jagged_output_kernel_(
const at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
x_values,
const std::array<index_t*, NUM_JAGGED_DIM> x_offsets,
const at::PackedTensorAccessor32<scalar_t, 3, at::RestrictPtrTraits> y,
at::PackedTensorAccessor32<scalar_t, 2, at::RestrictPtrTraits>
output_values,
const int64_t* jagged_dims,
F f) {
const int outer_dense_size = y.size(0);
const int jagged_folded_size = y.size(1);
const int inner_dense_size = y.size(2);

const int outer_begin = blockIdx.x * blockDim.y + threadIdx.y;
const int outer_stride = gridDim.x * blockDim.y;
for (int outer = outer_begin; outer < outer_dense_size * jagged_folded_size;
outer += outer_stride) {
const int oidx = outer / jagged_folded_size;
const int jidx = outer % jagged_folded_size;

int offset = oidx;
const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
offset, jidx, jagged_dims, x_offsets);

if (!is_zero) {
for (int iidx = threadIdx.x; iidx < inner_dense_size;
iidx += blockDim.x) {
output_values[offset][iidx] =
f(x_values[offset][iidx], y[oidx][jidx][iidx]);
}
}
}
}
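
Ignoring the thread/block mapping, the kernel's semantics reduce to the sequential sketch below (names mirror the kernel parameters; walk_down_tensor_storage_tree_ is the existing helper that maps the folded dense index jidx to a row offset into the jagged values and reports whether the position falls in padding):

// Sequential reference sketch of the kernel body (illustrative only).
for (int oidx = 0; oidx < outer_dense_size; ++oidx) {
  for (int jidx = 0; jidx < jagged_folded_size; ++jidx) {
    int offset = oidx;
    const bool is_zero = walk_down_tensor_storage_tree_<NUM_JAGGED_DIM>(
        offset, jidx, jagged_dims, x_offsets);
    if (!is_zero) { // padded positions are skipped; output_values is untouched
      for (int iidx = 0; iidx < inner_dense_size; ++iidx) {
        output_values[offset][iidx] =
            f(x_values[offset][iidx], y[oidx][jidx][iidx]);
      }
    }
  }
}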

template <typename scalar_t, typename F>
void jagged_dense_elementwise_jagged_output_(
const Tensor& x_values,
const std::vector<Tensor>& x_offsets,
const Tensor& y,
const Tensor& output_values,
F f) {
TENSOR_ON_CUDA_GPU(x_values);
for (auto& x_offset : x_offsets) {
TENSOR_ON_CUDA_GPU(x_offset);
}

const int num_jagged_dim = y.dim() - 2;
TORCH_CHECK(x_offsets.size() == static_cast<size_t>(num_jagged_dim));

dim3 threads, blocks;
Tensor jagged_dims_tensor;
std::tie(threads, blocks, jagged_dims_tensor) =
check_shape_and_partition_(x_values, x_offsets, y);

// Canonicalize y to 3D, collapsing jagged dimensions.
const Tensor y_reshaped = y.view({y.size(0), -1, y.size(-1)});

#define INVOKE_KERNEL_WITH_DIM(NUM_JAGGED_DIM) \
{ \
Tensor x_offsets_contig[num_jagged_dim]; \
std::array<index_t*, NUM_JAGGED_DIM> x_offset_ptrs; \
for (int d = 0; d < num_jagged_dim; ++d) { \
x_offsets_contig[d] = x_offsets[d].contiguous(); \
x_offset_ptrs[d] = x_offsets_contig[d].template data_ptr<index_t>(); \
} \
jagged_dense_elementwise_jagged_output_kernel_<NUM_JAGGED_DIM, index_t> \
<<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>( \
x_values.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
x_offset_ptrs, \
y_reshaped \
.packed_accessor32<scalar_t, 3, at::RestrictPtrTraits>(), \
output_values \
.packed_accessor32<scalar_t, 2, at::RestrictPtrTraits>(), \
jagged_dims_tensor.data_ptr<int64_t>(), \
f); \
}

JAGGED_TENSOR_DISPATCH_DIMS();
C10_CUDA_KERNEL_LAUNCH_CHECK();

#undef INVOKE_KERNEL_WITH_DIM
}
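
As a usage sketch (not a call site added by this commit), an element-wise jagged * dense multiplication with jagged output, the case the summary motivates, could be invoked as follows; x_values, x_offsets, y, and output_values are assumed to be suitably shaped tensors already in scope:

// Usage sketch only: multiply jagged values with the matching dense entries,
// writing a jagged result of the same shape as x_values.
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
    x_values.scalar_type(), "jagged_dense_mul_example", [&] {
      jagged_dense_elementwise_jagged_output_<scalar_t>(
          x_values,
          x_offsets,
          y,
          output_values,
          [] __device__(scalar_t x, scalar_t y) -> scalar_t {
            return x * y;
          });
    });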

} // namespace

Tensor
@@ -316,70 +380,30 @@ jagged_2d_to_dense_forward_cuda(Tensor values, Tensor offsets, int32_t max_L) {
return jagged_to_padded_dense(values, {offsets}, {max_L}, 0);
}

template <typename index_t, typename scalar_t>
__global__ void jagged_2d_to_dense_backward_kernel(
int32_t B,
int32_t max_L,
int32_t D,
index_t* offsets,
scalar_t* grad_padded_values,
scalar_t* grad_values) {
int32_t b_l = blockIdx.x * blockDim.y + threadIdx.y;
int32_t l = b_l / B;
int32_t b = b_l % B;
if (b_l >= B * max_L) {
return;
}
int32_t row_start = offsets[b];
int32_t row_end = offsets[b + 1];
int32_t length = row_end - row_start;
if (l < length) {
for (int32_t d = threadIdx.x; d < D; d += kWarpSize) {
grad_values[(row_start + l) * D + d] =
grad_padded_values[b * max_L * D + l * D + d];
}
}
}

Tensor jagged_2d_to_dense_backward_cuda(
Tensor grad_padded_values,
Tensor offsets,
int32_t total_L) {
TENSOR_ON_CUDA_GPU(grad_padded_values);
TENSOR_ON_CUDA_GPU(offsets);

TORCH_CHECK(grad_padded_values.dim() == 3);
TORCH_CHECK(offsets.dim() == 1);
TORCH_CHECK(total_L >= 0);
TORCH_CHECK(offsets.numel() == grad_padded_values.size(0) + 1);
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(grad_padded_values.get_device());

int32_t B = grad_padded_values.size(0);
int32_t max_L = grad_padded_values.size(1);
int32_t D = grad_padded_values.size(2);
auto grad_values = at::zeros({total_L, D}, grad_padded_values.options());
const auto grad_padded_values_config = grad_padded_values.contiguous();
const auto offsets_contig = offsets.contiguous();

AT_DISPATCH_INDEX_TYPES(
offsets.scalar_type(), "jagged_2d_to_dense_backward_kernel_1", [&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel_2",
[&] {
jagged_2d_to_dense_backward_kernel<index_t, scalar_t>
<<<div_round_up((B * max_L), kMaxThreads / kWarpSize),
dim3(kWarpSize, kMaxThreads / kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
B,
max_L,
D,
offsets_contig.data_ptr<index_t>(),
grad_padded_values_config.data_ptr<scalar_t>(),
grad_values.data_ptr<scalar_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values.scalar_type(),
"jagged_2d_to_dense_backward_kernel",
[&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values, // dummy not used in the lambda function
{offsets},
grad_padded_values,
grad_values,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});
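
With the identity lambda, the element-wise kernel acts as a pure gather: every valid jagged position copies the corresponding entry of grad_padded_values, while padded positions are skipped and grad_values keeps its zero initialization. For the 2D case with a single offsets tensor, this is equivalent to the sequential sketch below (B, max_L, and D denote the sizes of grad_padded_values):

// Conceptual reference for the backward gather (sketch only).
for (int b = 0; b < B; ++b) {
  const auto length = std::min<int64_t>(offsets[b + 1] - offsets[b], max_L);
  for (int l = 0; l < length; ++l) {
    for (int d = 0; d < D; ++d) {
      grad_values[offsets[b] + l][d] = grad_padded_values[b][l][d];
    }
  }
}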

@@ -469,38 +493,22 @@ Tensor stacked_jagged_2d_to_dense_backward_cuda(
at::zeros({total_L, D}, grad_padded_values_per_key[0].options());
int32_t T = grad_padded_values_per_key.size();
for (int32_t t = 0; t < T; t++) {
const auto grad_padded_values_config =
grad_padded_values_per_key[t].contiguous();
int64_t start = offset_per_key[t] * D;
const auto offsets_config = offsets_tensor_per_key[t].contiguous();
TORCH_CHECK(grad_padded_values_config.dim() == 3);
TORCH_CHECK(grad_padded_values_config.size(0) == B);
TORCH_CHECK(grad_padded_values_config.size(2) == D);
int32_t max_L = grad_padded_values_config.size(1);
AT_DISPATCH_INDEX_TYPES(
offsets_config.scalar_type(),
"jagged_2d_to_dense_backward_kernel_1",
[&] {
AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_padded_values_config.scalar_type(),
"jagged_2d_to_dense_backward_kernel_2",
[&] {
jagged_2d_to_dense_backward_kernel<index_t, scalar_t>
<<<fbgemm_gpu::div_round_up(
(B * max_L),
fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize),
dim3(
fbgemm_gpu::kWarpSize,
fbgemm_gpu::kMaxThreads / fbgemm_gpu::kWarpSize),
0,
at::cuda::getCurrentCUDAStream()>>>(
B,
max_L,
D,
offsets_config.data_ptr<index_t>(),
grad_padded_values_config.data_ptr<scalar_t>(),
&(grad_values.data_ptr<scalar_t>()[start]));
C10_CUDA_KERNEL_LAUNCH_CHECK();
TORCH_CHECK(grad_padded_values_per_key[t].dim() == 3);
TORCH_CHECK(grad_padded_values_per_key[t].size(0) == B);
TORCH_CHECK(grad_padded_values_per_key[t].size(2) == D);

Tensor grad_values_slice =
grad_values.slice(0, offset_per_key[t], offset_per_key[t + 1]);

AT_DISPATCH_FLOATING_TYPES_AND_HALF(
grad_values.scalar_type(), "jagged_2d_to_dense_backward_kernel", [&] {
jagged_dense_elementwise_jagged_output_<scalar_t>(
grad_values_slice, // dummy not used in the lambda function
{offsets_tensor_per_key[t]},
grad_padded_values_per_key[t],
grad_values_slice,
[] __device__(scalar_t /*unused*/, scalar_t y) -> scalar_t {
return y;
});
});
}
