Introduce populate_bucketized_permute kernel (#2533)
Summary:
Pull Request resolved: #2533

* This kernel is an unparameterized, batching-runtime counterpart of block_bucketize: it produces the bucketization permute that maps the permuted (bucketized) data back to the source data.
* Parallelism is currently limited to T and B; as L grows massively these days, an overall optimization for intra-bag parallelism needs to be revisited ASAP. A minimal reference sketch of the kernel's semantics follows below.
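
A minimal Python sketch of the semantics, for orientation only (the function name, the flat-list arguments, and the bucket-major layout comments are illustrative, not the library API):

    import itertools

    def populate_bucketized_permute_ref(lengths, bucketized_lengths, bucket_mapping):
        # lengths: T*B per-bag lengths; bucketized_lengths: my_size*T*B lengths,
        # laid out bucket-major; bucket_mapping: destination bucket per value.
        offsets = [0] + list(itertools.accumulate(lengths))
        # The complete cumsum of the bucketized lengths doubles as a running
        # write cursor per (bucket, feature x batch) segment.
        cursors = [0] + list(itertools.accumulate(bucketized_lengths))
        permute = [0] * len(bucket_mapping)
        n = len(lengths)
        for b_t in range(n):
            for i in range(lengths[b_t]):
                idx = offsets[b_t] + i
                bucket = bucket_mapping[idx]
                permute[idx] = cursors[bucket * n + b_t]
                cursors[bucket * n + b_t] += 1
        return permute

Gathering the bucketized values with this permute restores the original value order.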

Reviewed By: jiayisuse

Differential Revision: D55511724

fbshipit-source-id: 6388a073cdf49d9304b7831f198e631184f32304
YazhiGao authored and facebook-github-bot committed Apr 24, 2024
1 parent 26c55c6 commit e6c7541
Showing 5 changed files with 258 additions and 1 deletion.
12 changes: 12 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -194,6 +194,12 @@ block_bucketize_sparse_features_inference_cuda(
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping);

///@ingroup sparse-data-cuda
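/// Reconstructs the bucketize permute from the original `length` tensor, the
/// `bucketized_length` tensor, and the per-value `bucket_mapping` returned by
/// block_bucketize_sparse_features_inference with return_bucket_mapping=true.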
at::Tensor populate_bucketized_permute_cuda(
const at::Tensor& length,
const at::Tensor& bucketized_length,
const at::Tensor& bucket_mapping);

std::tuple<
at::Tensor,
at::Tensor,
@@ -216,6 +222,12 @@ block_bucketize_sparse_features_inference_cpu(
const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
const bool return_bucket_mapping);

///@ingroup sparse-data-cpu
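/// CPU counterpart of populate_bucketized_permute_cuda; identical semantics.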
at::Tensor populate_bucketized_permute_cpu(
const at::Tensor& length,
const at::Tensor& bucketized_length,
const at::Tensor& bucket_mapping);

std::tuple<
at::Tensor,
at::Tensor,
69 changes: 69 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
@@ -297,6 +297,27 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
}
}

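// Kernel: for every original value, compute its destination position in the
// bucketized layout. Each CUDA_KERNEL_LOOP iteration owns one (feature, batch)
// pair b_t; it walks that bag's values in order and hands each value the next
// free slot of its (bucket, b_t) segment, advancing bucketized_offsets_data in
// place as a running write cursor.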
template <typename offset_t, typename index_t>
__global__
__launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
const offset_t* const length_data,
const offset_t* const offset_data,
offset_t* const bucketized_offsets_data,
const index_t* const bucket_mapping_data,
index_t* const bucketized_permute_data_out,
int32_t lengths_size) {
CUDA_KERNEL_LOOP(b_t, lengths_size) {
const auto length = length_data[b_t];
const auto offset = offset_data[b_t];
for (size_t i = 0; i < length; i++) {
const auto index = offset + i;
const auto bucket = bucket_mapping_data[index];
bucketized_permute_data_out[index] =
bucketized_offsets_data[bucket * lengths_size + b_t]++;
}
}
}

// This function partitions sparse features
// continuously along the sparse dimension into my_size blocks
std::tuple<
@@ -932,6 +953,50 @@ block_bucketize_sparse_features_inference_cuda(
return_bucket_mapping);
}
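// Host-side wrapper: takes complete cumsums of `lengths` and
// `bucketized_lengths` (the latter is consumed by the kernel as its
// per-(bucket, b_t) write cursor) and launches one worker per
// (feature, batch) entry.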
DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
const Tensor& lengths,
const Tensor& bucketized_lengths,
const Tensor& bucket_mapping) {
TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
lengths, bucketized_lengths, bucket_mapping);
CUDA_DEVICE_GUARD(lengths);
const auto lengths_contig = lengths.expect_contiguous();
const auto bucketized_lengths_contig = bucketized_lengths.expect_contiguous();
const auto bucket_mapping_contig = bucket_mapping.expect_contiguous();
Tensor bucketized_permute = at::empty_like(*bucket_mapping_contig);
const auto offsets = asynchronous_complete_cumsum_gpu(*lengths_contig);
const auto bucketized_offsets =
asynchronous_complete_cumsum_gpu(*bucketized_lengths_contig);
constexpr auto threads_per_block = 256;
const auto lengths_size = lengths.numel();
const auto num_blocks =
cuda_calc_xblock_count(lengths_size, threads_per_block);
AT_DISPATCH_INDEX_TYPES(
lengths_contig->scalar_type(),
"_populate_bucketized_permute_cuda_kernel1",
[&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
bucket_mapping_contig->scalar_type(),
"_populate_bucketized_permute_cuda_kernel2",
[&] {
_populate_bucketized_permute_cuda_kernel<<<
num_blocks,
threads_per_block,
0,
at::cuda::getCurrentCUDAStream()>>>(
lengths_contig->data_ptr<offset_t>(),
offsets.data_ptr<offset_t>(),
bucketized_offsets.data_ptr<offset_t>(),
bucket_mapping_contig->data_ptr<index_t>(),
bucketized_permute.data_ptr<index_t>(),
lengths.numel());
C10_CUDA_KERNEL_LAUNCH_CHECK();
});
});
return bucketized_permute;
}
} // namespace fbgemm_gpu
FBGEMM_OP_DISPATCH(
@@ -942,3 +1007,7 @@ FBGEMM_OP_DISPATCH(
CUDA,
"block_bucketize_sparse_features_inference",
fbgemm_gpu::block_bucketize_sparse_features_inference_cuda);
FBGEMM_OP_DISPATCH(
CUDA,
"populate_bucketized_permute",
fbgemm_gpu::populate_bucketized_permute_cuda);
55 changes: 55 additions & 0 deletions fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -8,6 +8,7 @@

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <functional>

#include <ATen/ATen.h>
@@ -935,6 +936,55 @@ Tensor invert_permute_cpu(const Tensor& permute) {

return inversed_permute;
}
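// CPU counterpart of _populate_bucketized_permute_cuda_kernel: walk each
// (feature, batch) bag in order and assign every value the next open slot of
// its destination bucket, using bucketized_offsets_data as a running write
// cursor.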
template <typename index_t, typename offset_t>
void _populate_bucketized_permute_cpu(
const offset_t* const length_data,
const offset_t* const offset_data,
offset_t* const bucketized_offsets_data,
const index_t* const bucket_mapping_data,
index_t* const bucketized_permute_data_out,
int64_t length_size) {
for (const auto i : c10::irange(length_size)) {
const auto length = length_data[i];
const auto offset = offset_data[i];
for (const auto j : c10::irange(length)) {
const auto index = offset + j;
const auto bucket = bucket_mapping_data[index];
bucketized_permute_data_out[index] =
bucketized_offsets_data[bucket * length_size + i]++;
}
}
}

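// CPU entry point: mirrors the CUDA wrapper; it computes complete cumsums of
// the two length tensors and dispatches on the offset and index dtypes before
// calling the helper above.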
Tensor populate_bucketized_permute_cpu(
const Tensor& lengths,
const Tensor& bucketized_lengths,
const Tensor& bucket_mapping) {
const auto lengths_contig = lengths.expect_contiguous();
const auto bucketized_lengths_contig = bucketized_lengths.expect_contiguous();
const auto bucket_mapping_contig = bucket_mapping.expect_contiguous();
Tensor bucketized_permute = native_empty_like(*bucket_mapping_contig);
const auto offsets = asynchronous_complete_cumsum_cpu(*lengths_contig);
const auto bucketized_offsets =
asynchronous_complete_cumsum_cpu(*bucketized_lengths_contig);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "populate_bucketized_permute_cpu_1", ([&] {
using offset_t = index_t;
AT_DISPATCH_INDEX_TYPES(
bucket_mapping_contig->scalar_type(),
"populate_bucketized_permute_cpu_2",
([&] {
_populate_bucketized_permute_cpu<index_t, offset_t>(
lengths_contig->data_ptr<offset_t>(),
offsets.data_ptr<offset_t>(),
bucketized_offsets.data_ptr<offset_t>(),
bucket_mapping_contig->data_ptr<index_t>(),
bucketized_permute.data_ptr<index_t>(),
lengths_contig->numel());
}));
}));
return bucketized_permute;
}

std::tuple<
Tensor,
@@ -2888,6 +2938,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor",
{PT2_COMPLIANT_TAG});
m.def(
"populate_bucketized_permute(Tensor lengths, Tensor bucketized_lengths, Tensor bucket_mapping) -> Tensor");
m.def(
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
m.def(
@@ -2988,6 +3040,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
DISPATCH_TO_CPU("invert_permute", fbgemm_gpu::invert_permute_cpu);
DISPATCH_TO_CPU(
"expand_into_jagged_permute", fbgemm_gpu::expand_into_jagged_permute_cpu);
DISPATCH_TO_CPU(
"populate_bucketized_permute",
fbgemm_gpu::populate_bucketized_permute_cpu);
DISPATCH_TO_CPU(
"block_bucketize_sparse_features",
fbgemm_gpu::block_bucketize_sparse_features_cpu);
105 changes: 104 additions & 1 deletion fbgemm_gpu/test/sparse/block_bucketize_test.py
@@ -357,7 +357,6 @@ def test_block_bucketize_sparse_features_inference(
self,
index_type: Type[torch.dtype],
) -> None:
B = 2
# pyre-ignore [6]
lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type)
indices = torch.tensor(
@@ -429,6 +428,110 @@ def test_block_bucketize_sparse_features_inference(
bucket_mapping,
)

@skipIfRocm(ROCM_FAILURE_MESSAGE)
@given(
index_type=st.sampled_from([torch.int, torch.long]),
)
@settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
def test_populate_bucketized_permute(
self,
index_type: Type[torch.dtype],
) -> None:
# pyre-ignore [6]
lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type)
indices = torch.tensor(
[3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20],
# pyre-ignore [6]
dtype=index_type,
)
# pyre-ignore [6]
block_sizes = torch.tensor([5, 15, 10, 20], dtype=index_type)
my_size = 2

new_lengths_ref = torch.tensor(
[0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1],
# pyre-ignore [6]
dtype=index_type,
)
new_indices_ref = torch.tensor(
[3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0],
# pyre-ignore [6]
dtype=index_type,
)
(
new_lengths_cpu,
new_indices_cpu,
_,
_,
unbucketize_permute_cpu,
bucket_mapping_cpu,
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
lengths,
indices,
False,
True,
block_sizes,
my_size,
None,
return_bucket_mapping=True,
)

unbucketize_permute_populated_cpu = (
torch.ops.fbgemm.populate_bucketized_permute(
lengths,
new_lengths_cpu,
bucket_mapping_cpu,
)
)
torch.testing.assert_close(
unbucketize_permute_populated_cpu, unbucketize_permute_cpu, rtol=0, atol=0
)
torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0)
torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0)

if gpu_available:
(
new_lengths_gpu,
new_indices_gpu,
_,
_,
unbucketize_permute_gpu,
bucket_mapping_gpu,
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
lengths.cuda(),
indices.cuda(),
False,
True,
block_sizes.cuda(),
my_size,
None,
return_bucket_mapping=True,
)

unbucketize_permute_populated_gpu = (
torch.ops.fbgemm.populate_bucketized_permute(
lengths.cuda(),
new_lengths_gpu,
bucket_mapping_gpu,
)
)
torch.testing.assert_close(
unbucketize_permute_gpu.cpu(),
unbucketize_permute_populated_gpu.cpu(),
rtol=0,
atol=0,
)
torch.testing.assert_close(
new_lengths_gpu.cpu(), new_lengths_ref, rtol=0, atol=0
)
torch.testing.assert_close(
new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0
)
torch.testing.assert_close(
bucket_mapping_gpu.cpu(), bucket_mapping_cpu, rtol=0, atol=0
)

@skipIfRocm(ROCM_FAILURE_MESSAGE)
@given(
index_type=st.sampled_from([torch.int, torch.long]),
18 changes: 18 additions & 0 deletions fbgemm_gpu/test/sparse/failures_dict.json
@@ -66,9 +66,17 @@
"comment": "",
"status": "xfail"
},
"BlockBucketizeTest.test_aot_dispatch_dynamic__test_populate_bucketized_permute": {
"comment": "",
"status": "xfail"
},
"BlockBucketizeTest.test_faketensor__test_block_bucketize_sparse_features_inference": {
"comment": "",
"status": "xfail"
},
"BlockBucketizeTest.test_faketensor__test_populate_bucketized_permute": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::bottom_k_per_row": {
@@ -182,6 +190,16 @@
"status": "xfail"
}
},
"fbgemm::populate_bucketized_permute": {
"BlockBucketizeTest.test_aot_dispatch_dynamic__test_populate_bucketized_permute": {
"comment": "",
"status": "xfail"
},
"BlockBucketizeTest.test_faketensor__test_populate_bucketized_permute": {
"comment": "",
"status": "xfail"
}
},
"fbgemm::reorder_batched_ad_indices": {
"ReorderBatchedTest.test_aot_dispatch_dynamic__test_reorder_batched_ad_indices": {
"comment": "",
