batched dense vector x jagged 2D multiplication (pytorch#997)
Summary: Pull Request resolved: pytorch#997
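
The new op multiplies a batch of dense vectors v of shape [B * H, N] with a jagged 2D tensor given as values [total_rows, H * D] plus offsets [B + 1], producing a dense [B * H, D] output. Below is a minimal CPU sketch of the intended semantics, using illustrative names (the CUDA implementation is in the diff that follows):

#include <algorithm>
#include <vector>

// CPU reference for the op's semantics; all names here are illustrative.
void batched_dense_vec_jagged_2d_mul_ref(
    const std::vector<float>& v,        // [B * H, max_L], row-major
    const std::vector<float>& a_values, // [total_rows, H * D], row-major
    const std::vector<int>& a_offsets,  // [B + 1]
    std::vector<float>& output,         // [B * H, D], row-major
    int B, int H, int max_L, int D) {
  for (int b = 0; b < B; ++b) {
    const int row_start = a_offsets[b];
    // Bags longer than max_L are truncated, matching the CUDA kernel.
    const int length = std::min(a_offsets[b + 1] - row_start, max_L);
    for (int h = 0; h < H; ++h) {
      const int b_h = b * H + h;
      for (int d = 0; d < D; ++d) {
        float acc = 0.f; // empty bags yield zeros
        for (int l = 0; l < length; ++l) {
          acc += v[b_h * max_L + l] *
              a_values[(row_start + l) * H * D + h * D + d];
        }
        output[b_h * D + d] = acc;
      }
    }
  }
}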

Reviewed By: xing-liu

Differential Revision: D34876009

fbshipit-source-id: 72f621b8ad96a987fb8d11e9e18116b2a3f5adfb
jspark1105 authored and facebook-github-bot committed Mar 21, 2022
1 parent 2f83ca0 commit 7dff50b
Showing 3 changed files with 518 additions and 0 deletions.
252 changes: 252 additions & 0 deletions fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -616,6 +616,255 @@ std::tuple<Tensor, std::vector<Tensor>> jagged_dense_elementwise_mul(
  return {JaggedDenseMulGPUOp::apply(x_values, x_offsets, y)[0], x_offsets};
}

template <typename index_t, typename scalar_t>
__global__ void dense_vec_jagged_2d_bmm(
    const at::PackedTensorAccessor32<scalar_t, 2> v,
    const at::PackedTensorAccessor32<scalar_t, 2> a_values,
    const at::PackedTensorAccessor32<index_t, 1> a_offsets,
    at::PackedTensorAccessor32<scalar_t, 2> output) {
  const int B = a_offsets.size(0) - 1;
  const int H = v.size(0) / B;
  const int max_L = v.size(1);
  const int D = output.size(1);

  const int b_h_begin = blockIdx.x * blockDim.y + threadIdx.y;
  const int b_h_step = gridDim.x * blockDim.y;
  for (int b_h = b_h_begin; b_h < B * H; b_h += b_h_step) {
    const int b = b_h / H;
    const int h = b_h % H;

    const int row_start = a_offsets[b];
    const int row_end = a_offsets[b + 1];
    const int length = std::min(row_end - row_start, max_L);
    if (length == 0) {
      for (int d = threadIdx.x; d < D; d += blockDim.x) {
        output[b_h][d] = 0;
      }
    } else {
      // TODO: use shared memory
      for (int d = threadIdx.x; d < D; d += blockDim.x) {
        at::acc_type<scalar_t, true> acc =
            v[b_h][0] * a_values[row_start][h * D + d];
        for (int l = 1; l < length; ++l) {
          acc += v[b_h][l] * a_values[row_start + l][h * D + d];
        }
        output[b_h][d] = acc;
      }
    }
  }
}
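
// Math note: for each batch-head pair (b, h) = (b_h / H, b_h % H), the kernel
// above computes
//   output[b_h][d] = sum_{l = 0}^{L_b - 1} v[b_h][l] * a_values[row_start + l][h * D + d]
// with row_start = a_offsets[b] and L_b = min(row_end - row_start, max_L);
// bags longer than max_L are truncated, and empty bags write zeros.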

template <typename index_t, typename scalar_t>
__global__ void dense_vec_jagged_2d_transposed_bmm(
    const at::PackedTensorAccessor32<scalar_t, 2> v,
    const at::PackedTensorAccessor32<scalar_t, 2> a_values,
    const at::PackedTensorAccessor32<index_t, 1> a_offsets,
    at::PackedTensorAccessor32<scalar_t, 2> output) {
  const int B = a_offsets.size(0) - 1;
  const int H = v.size(0) / B;
  const int max_L = output.size(1);
  const int D = v.size(1);

  const int b_h_begin = blockIdx.x * blockDim.y + threadIdx.y;
  const int b_h_step = gridDim.x * blockDim.y;
  for (int b_h = b_h_begin; b_h < B * H; b_h += b_h_step) {
    const int b = b_h / H;
    const int h = b_h % H;

    const int row_start = a_offsets[b];
    const int row_end = a_offsets[b + 1];
    const int length = std::min(row_end - row_start, max_L);
    if (D == 0) {
      for (int l = threadIdx.x; l < max_L; l += blockDim.x) {
        output[b_h][l] = 0;
      }
    } else {
      int l;
      for (l = threadIdx.x; l < length; l += blockDim.x) {
        at::acc_type<scalar_t, true> acc =
            v[b_h][0] * a_values[row_start + l][h * D];
        for (int d = 1; d < D; ++d) {
          acc += v[b_h][d] * a_values[row_start + l][h * D + d];
        }
        output[b_h][l] = acc;
      }
      for (; l < max_L; l += blockDim.x) {
        output[b_h][l] = 0;
      }
    }
  }
}
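
// Math note: this transposed variant is used in the backward pass to form the
// gradient of the dense vectors; with the incoming gradient g passed as `v`,
//   output[b_h][l] = sum_{d = 0}^{D - 1} g[b_h][d] * a_values[row_start + l][h * D + d]
// for l < L_b, while the padded tail L_b <= l < max_L is zero-filled.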

template <typename index_t, typename scalar_t>
__global__ void outer_prod_jagged_2d_output(
    const at::PackedTensorAccessor32<scalar_t, 2> x,
    const at::PackedTensorAccessor32<scalar_t, 2> y,
    const at::PackedTensorAccessor32<index_t, 1> offsets,
    at::PackedTensorAccessor32<scalar_t, 2> output_values) {
  const int B = offsets.size(0) - 1;
  const int H = x.size(0) / B;
  const int max_L = x.size(1);
  const int D = y.size(1);

  const int b_h_l_begin = blockIdx.x * blockDim.y + threadIdx.y;
  const int b_h_l_step = gridDim.x * blockDim.y;
  for (int b_h_l = b_h_l_begin; b_h_l < B * H * max_L; b_h_l += b_h_l_step) {
    const int b_h = b_h_l / max_L;
    const int b = b_h / H;
    const int h = b_h % H;
    const int l = b_h_l % max_L;

    const int row_start = offsets[b];
    const int row_end = offsets[b + 1];
    const int length = row_end - row_start;
    if (l < length) {
      for (int d = threadIdx.x; d < D; d += blockDim.x) {
        output_values[row_start + l][h * D + d] = x[b_h][l] * y[b_h][d];
      }
    }
  }
}
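
// Math note: used in the backward pass for the gradient of the jagged values;
// with x = v and y = grad_output, each valid row receives the outer product
//   output_values[row_start + l][h * D + d] = x[b_h][l] * y[b_h][d],
// written only for l < length, i.e., for rows the bag actually contains.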

// batched dense vector x jagged 2D tensor multiplication
// dense vector: [B * H, N]
// jagged tensor: [B, N, H * D] where N is the jagged dimension
class BatchedDenseVecJagged2DMulGPUOp
    : public torch::autograd::Function<BatchedDenseVecJagged2DMulGPUOp> {
 public:
  static torch::autograd::variable_list forward(
      torch::autograd::AutogradContext* ctx,
      const Tensor& v,
      const Tensor& a_values,
      const Tensor& a_offsets) {
    ctx->save_for_backward({v, a_values, a_offsets});

    TENSOR_ON_CUDA_GPU(v);
    TENSOR_ON_CUDA_GPU(a_values);
    TENSOR_ON_CUDA_GPU(a_offsets);

    at::cuda::OptionalCUDAGuard device_guard;
    device_guard.set_index(v.get_device());

    const int B = a_offsets.numel() - 1;
    TORCH_CHECK(B == 0 || v.size(0) % B == 0);
    const int H = (B == 0) ? 1 : v.size(0) / B;
    const int D = a_values.size(-1) / H;
    const int max_L = v.size(-1);
    auto output = at::empty({B * H, D}, v.options());

    if (B > 0 && D > 0) {
      const int block_dim_x =
          std::min(div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
      const int block_dim_y = kMaxThreads / block_dim_x;

      AT_DISPATCH_INDEX_TYPES(
          a_offsets.scalar_type(), "dense_vec_jagged_2d_bmm_kernel_1", [&] {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
                a_values.scalar_type(),
                "dense_vec_jagged_2d_bmm_kernel_2",
                [&] {
                  dense_vec_jagged_2d_bmm<index_t, scalar_t>
                      <<<div_round_up(B * H, block_dim_y),
                         dim3(block_dim_x, block_dim_y),
                         0,
                         at::cuda::getCurrentCUDAStream()>>>(
                          v.packed_accessor32<scalar_t, 2>(),
                          a_values.packed_accessor32<scalar_t, 2>(),
                          a_offsets.packed_accessor32<index_t, 1>(),
                          output.packed_accessor32<scalar_t, 2>());
                  C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
          });
    }
    return {output};
  }
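
  // Launch-config note, with illustrative numbers: if kMaxThreads = 1024 and
  // kWarpSize = 32, then D = 100 gives block_dim_x = div_round_up(100, 32) *
  // 32 = 128 threads striding over D, block_dim_y = 1024 / 128 = 8 (b, h)
  // pairs per block, and a grid of div_round_up(B * H, 8) blocks.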
  static torch::autograd::variable_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::variable_list grad_outputs) {
    const auto saved = ctx->get_saved_variables();
    auto savedItr = std::begin(saved);
    const Tensor v = *savedItr++;
    const Tensor a_values = *savedItr++;
    const Tensor a_offsets = *savedItr++;
    TORCH_CHECK(grad_outputs.size() == 1);

    TENSOR_ON_CUDA_GPU(grad_outputs[0]);
    TENSOR_ON_CUDA_GPU(a_values);
    TENSOR_ON_CUDA_GPU(a_offsets);
    TENSOR_ON_CUDA_GPU(v);

    at::cuda::OptionalCUDAGuard device_guard;
    device_guard.set_index(grad_outputs[0].get_device());

    const int B = a_offsets.numel() - 1;
    const int D = grad_outputs[0].size(-1);
    Tensor a_values_grad = at::empty_like(a_values);
    Tensor v_grad = at::empty_like(v);

    if (B > 0 && D > 0) {
      TORCH_CHECK(v.size(0) % B == 0);
      const int H = v.size(0) / B;
      const int max_L = v.size(-1);

      AT_DISPATCH_INDEX_TYPES(
          a_offsets.scalar_type(),
          "dense_vec_jagged_2d_bmm_backward_kernel_1",
          [&] {
            AT_DISPATCH_FLOATING_TYPES_AND_HALF(
                grad_outputs[0].scalar_type(),
                "dense_vec_jagged_2d_bmm_backward_kernel_2",
                [&] {
                  int block_dim_x = std::min(
                      div_round_up(max_L, kWarpSize) * kWarpSize, kMaxThreads);
                  int block_dim_y = kMaxThreads / block_dim_x;

                  // Gradient w.r.t. the dense vectors: transposed product.
                  dense_vec_jagged_2d_transposed_bmm<index_t, scalar_t>
                      <<<div_round_up(B * H, block_dim_y),
                         dim3(block_dim_x, block_dim_y),
                         0,
                         at::cuda::getCurrentCUDAStream()>>>(
                          grad_outputs[0].packed_accessor32<scalar_t, 2>(),
                          a_values.packed_accessor32<scalar_t, 2>(),
                          a_offsets.packed_accessor32<index_t, 1>(),
                          v_grad.packed_accessor32<scalar_t, 2>());
                  C10_CUDA_KERNEL_LAUNCH_CHECK();

                  block_dim_x = std::min(
                      div_round_up(D, kWarpSize) * kWarpSize, kMaxThreads);
                  block_dim_y = kMaxThreads / block_dim_x;

                  // Gradient w.r.t. the jagged values: outer products.
                  outer_prod_jagged_2d_output<index_t, scalar_t>
                      <<<div_round_up(B * H * max_L, block_dim_y),
                         dim3(block_dim_x, block_dim_y),
                         0,
                         at::cuda::getCurrentCUDAStream()>>>(
                          v.packed_accessor32<scalar_t, 2>(),
                          grad_outputs[0].packed_accessor32<scalar_t, 2>(),
                          a_offsets.packed_accessor32<index_t, 1>(),
                          a_values_grad.packed_accessor32<scalar_t, 2>());
                  C10_CUDA_KERNEL_LAUNCH_CHECK();
                });
          });
    }
    return {
        v_grad,
        a_values_grad,
        torch::autograd::Variable(), // a_offsets
    };
  }
};
Tensor batched_dense_vec_jagged_2d_mul(
    const Tensor& v,
    const Tensor& a_values,
    const Tensor& a_offsets) {
  return BatchedDenseVecJagged2DMulGPUOp::apply(v, a_values, a_offsets)[0];
}

} // namespace
Tensor
@@ -826,4 +1075,7 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
"jagged_dense_elementwise_add", fbgemm_gpu::jagged_dense_elementwise_add);
DISPATCH_TO_CUDA(
"jagged_dense_elementwise_mul", fbgemm_gpu::jagged_dense_elementwise_mul);
DISPATCH_TO_CUDA(
"batched_dense_vec_jagged_2d_mul",
fbgemm_gpu::batched_dense_vec_jagged_2d_mul);
}
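
A hypothetical call-site sketch (not part of this change) exercising the new wrapper; it assumes a CUDA device and linking against fbgemm_gpu, and uses B = 2, H = 2, max_L = 4, D = 3:

#include <torch/torch.h>

namespace fbgemm_gpu {
// Assumed declaration; the signature mirrors the wrapper added in this diff.
at::Tensor batched_dense_vec_jagged_2d_mul(
    const at::Tensor& v,
    const at::Tensor& a_values,
    const at::Tensor& a_offsets);
} // namespace fbgemm_gpu

int main() {
  // Offsets [0, 3, 5]: bag 0 contributes 3 rows, bag 1 contributes 2.
  auto a_offsets = torch::tensor(
      {0, 3, 5}, torch::dtype(torch::kInt32).device(torch::kCUDA));
  auto a_values =
      torch::randn({5, 2 * 3}, torch::device(torch::kCUDA)); // [rows, H * D]
  auto v =
      torch::randn({2 * 2, 4}, torch::device(torch::kCUDA)); // [B * H, max_L]
  auto out = fbgemm_gpu::batched_dense_vec_jagged_2d_mul(v, a_values, a_offsets);
  TORCH_CHECK(out.size(0) == 4 && out.size(1) == 3); // [B * H, D]
  return 0;
}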