From e6c754160d0d9948764d12a50bfe81524b08cc8e Mon Sep 17 00:00:00 2001
From: Leon Gao
Date: Tue, 23 Apr 2024 18:36:32 -0700
Subject: [PATCH] introduce populate bucketize offsets kernel (#2533)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/2533

* this kernel is a runtime-batched, unparameterized version of the block_bucketize kernel; it produces the bucketization permute that maps from the permuted data back to the source data.
* parallelism is currently limited to T and B; since L has been growing massively these days, an overall optimization for intra-bag parallelism should be revisited ASAP.
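
For reference, a minimal usage sketch (operator names and signatures are taken
from this diff; tensor values mirror the unit test below, and fbgemm_gpu must be
imported so the torch.ops.fbgemm operators are registered):

    import torch
    import fbgemm_gpu  # noqa: F401

    lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=torch.int)
    indices = torch.tensor(
        [3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20],
        dtype=torch.int,
    )
    block_sizes = torch.tensor([5, 15, 10, 20], dtype=torch.int)

    # Bucketize once, asking for the per-value bucket assignment.
    (
        new_lengths,
        _new_indices,
        _,
        _,
        unbucketize_permute,
        bucket_mapping,
    ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
        lengths,
        indices,
        False,  # bucketize_pos
        True,  # sequence
        block_sizes,
        2,  # my_size
        None,  # weights
        return_bucket_mapping=True,
    )

    # Later, rebuild the same permute from the cached lengths and
    # bucket_mapping without re-running the full bucketization.
    permute = torch.ops.fbgemm.populate_bucketized_permute(
        lengths, new_lengths, bucket_mapping
    )
    assert torch.equal(permute, unbucketize_permute)
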
Reviewed By: jiayisuse

Differential Revision: D55511724

fbshipit-source-id: 6388a073cdf49d9304b7831f198e631184f32304
---
 fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h    |  12 ++
 .../sparse_block_bucketize_features.cu        |  69 ++++++++++++
 fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp  |  55 +++++++++
 .../test/sparse/block_bucketize_test.py       | 105 +++++++++++++++++-
 fbgemm_gpu/test/sparse/failures_dict.json     |  18 +++
 5 files changed, 258 insertions(+), 1 deletion(-)

diff --git a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
index 0863b11483..8cd2966b26 100644
--- a/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
+++ b/fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -194,6 +194,12 @@ block_bucketize_sparse_features_inference_cuda(
     const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping);
 
+///@ingroup sparse-data-cuda
+at::Tensor populate_bucketized_permute_cuda(
+    const at::Tensor& length,
+    const at::Tensor& bucketized_length,
+    const at::Tensor& bucket_mapping);
+
 std::tuple<
     at::Tensor,
     at::Tensor,
@@ -216,6 +222,12 @@ block_bucketize_sparse_features_inference_cpu(
     const c10::optional<std::vector<at::Tensor>>& block_bucketize_pos,
     const bool return_bucket_mapping);
 
+///@ingroup sparse-data-cpu
+at::Tensor populate_bucketized_permute_cpu(
+    const at::Tensor& length,
+    const at::Tensor& bucketized_length,
+    const at::Tensor& bucket_mapping);
+
 std::tuple<
     at::Tensor,
     at::Tensor,
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
index 262f4ec988..f1642111b8 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
+++ b/fbgemm_gpu/src/sparse_ops/sparse_block_bucketize_features.cu
@@ -297,6 +297,27 @@ __launch_bounds__(kMaxThreads) void _block_bucketize_sequence_sparse_features_cu
   }
 }
 
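+// One thread handles one (feature, batch) bag of the flattened lengths. For
+// each value in the bag, read the bucket chosen by block_bucketize and hand
+// out the value's destination slot by post-incrementing the running cursor
+// of that (bucket, bag) pair in bucketized_offsets_data. Values within a
+// bag are processed serially, so intra-bag ordering is preserved.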
+template <typename offset_t, typename index_t>
+__global__
+__launch_bounds__(kMaxThreads) void _populate_bucketized_permute_cuda_kernel(
+    const offset_t* const length_data,
+    const offset_t* const offset_data,
+    offset_t* const bucketized_offsets_data,
+    const index_t* const bucket_mapping_data,
+    index_t* const bucketized_permute_data_out,
+    int32_t lengths_size) {
+  CUDA_KERNEL_LOOP(b_t, lengths_size) {
+    const auto length = length_data[b_t];
+    const auto offset = offset_data[b_t];
+    for (size_t i = 0; i < length; i++) {
+      const auto index = offset + i;
+      const auto bucket = bucket_mapping_data[index];
+      bucketized_permute_data_out[index] =
+          bucketized_offsets_data[bucket * lengths_size + b_t]++;
+    }
+  }
+}
+
 // This function partitions sparse features
 // continuously along the sparse dimension into my_size blocks
 std::tuple<
@@ -932,6 +953,50 @@ block_bucketize_sparse_features_inference_cuda(
       return_bucket_mapping);
 }
 
+DLL_PUBLIC Tensor populate_bucketized_permute_cuda(
+    const Tensor& lengths,
+    const Tensor& bucketized_lengths,
+    const Tensor& bucket_mapping) {
+  TENSORS_ON_SAME_CUDA_GPU_IF_NOT_OPTIONAL(
+      lengths, bucketized_lengths, bucket_mapping);
+  CUDA_DEVICE_GUARD(lengths);
+  const auto lengths_contig = lengths.expect_contiguous();
+  const auto bucketized_lengths_contig = bucketized_lengths.expect_contiguous();
+  const auto bucket_mapping_contig = bucket_mapping.expect_contiguous();
+  Tensor bucketized_permute = at::empty_like(*bucket_mapping_contig);
+  const auto offsets = asynchronous_complete_cumsum_gpu(*lengths_contig);
+  const auto bucketized_offsets =
+      asynchronous_complete_cumsum_gpu(*bucketized_lengths_contig);
+  constexpr auto threads_per_block = 256;
+  const auto lengths_size = lengths.numel();
+  const auto num_blocks =
+      cuda_calc_xblock_count(lengths_size, threads_per_block);
+  AT_DISPATCH_INDEX_TYPES(
+      lengths_contig->scalar_type(),
+      "_populate_bucketized_permute_cuda_kernel1",
+      [&] {
+        using offset_t = index_t;
+        AT_DISPATCH_INDEX_TYPES(
+            bucket_mapping_contig->scalar_type(),
+            "_populate_bucketized_permute_cuda_kernel2",
+            [&] {
+              _populate_bucketized_permute_cuda_kernel<<<
+                  num_blocks,
+                  threads_per_block,
+                  0,
+                  at::cuda::getCurrentCUDAStream()>>>(
+                  lengths_contig->data_ptr<offset_t>(),
+                  offsets.data_ptr<offset_t>(),
+                  bucketized_offsets.data_ptr<offset_t>(),
+                  bucket_mapping_contig->data_ptr<index_t>(),
+                  bucketized_permute.data_ptr<index_t>(),
+                  lengths.numel());
+              C10_CUDA_KERNEL_LAUNCH_CHECK();
+            });
+      });
+  return bucketized_permute;
+}
+
 } // namespace fbgemm_gpu
 
 FBGEMM_OP_DISPATCH(
@@ -942,3 +1007,7 @@ FBGEMM_OP_DISPATCH(
     CUDA,
     "block_bucketize_sparse_features_inference",
     fbgemm_gpu::block_bucketize_sparse_features_inference_cuda);
+FBGEMM_OP_DISPATCH(
+    CUDA,
+    "populate_bucketized_permute",
+    fbgemm_gpu::populate_bucketized_permute_cuda);
diff --git a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
index d331273699..1e24220f30 100644
--- a/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
+++ b/fbgemm_gpu/src/sparse_ops/sparse_ops_cpu.cpp
@@ -8,6 +8,7 @@
 #include
 #include
+#include
 #include
 #include
@@ -935,6 +936,55 @@ Tensor invert_permute_cpu(const Tensor& permute) {
   return inversed_permute;
 }
 
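+// Sequential CPU counterpart of the CUDA kernel above: walk each bag in
+// order and assign every value its slot in the bucketized output by bumping
+// the per-(bucket, bag) cursor.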
+template <typename offset_t, typename index_t>
+void _populate_bucketized_permute_cpu(
+    const offset_t* const length_data,
+    const offset_t* const offset_data,
+    offset_t* const bucketized_offsets_data,
+    const index_t* const bucket_mapping_data,
+    index_t* const bucketized_permute_data_out,
+    int64_t length_size) {
+  for (const auto i : c10::irange(length_size)) {
+    const auto length = length_data[i];
+    const auto offset = offset_data[i];
+    for (const auto j : c10::irange(length)) {
+      const auto index = offset + j;
+      const auto bucket = bucket_mapping_data[index];
+      bucketized_permute_data_out[index] =
+          bucketized_offsets_data[bucket * length_size + i]++;
+    }
+  }
+}
+
+Tensor populate_bucketized_permute_cpu(
+    const Tensor& lengths,
+    const Tensor& bucketized_lengths,
+    const Tensor& bucket_mapping) {
+  const auto lengths_contig = lengths.expect_contiguous();
+  const auto bucketized_lengths_contig = bucketized_lengths.expect_contiguous();
+  const auto bucket_mapping_contig = bucket_mapping.expect_contiguous();
+  Tensor bucketized_permute = native_empty_like(*bucket_mapping_contig);
+  const auto offsets = asynchronous_complete_cumsum_cpu(*lengths_contig);
+  const auto bucketized_offsets =
+      asynchronous_complete_cumsum_cpu(*bucketized_lengths_contig);
+  AT_DISPATCH_INDEX_TYPES(
+      lengths.scalar_type(), "populate_bucketized_permute_cpu_1", ([&] {
+        using offset_t = index_t;
+        AT_DISPATCH_INDEX_TYPES(
+            bucket_mapping_contig->scalar_type(),
+            "populate_bucketized_permute_cpu_2",
+            ([&] {
+              _populate_bucketized_permute_cpu(
+                  lengths_contig->data_ptr<offset_t>(),
+                  offsets.data_ptr<offset_t>(),
+                  bucketized_offsets.data_ptr<offset_t>(),
+                  bucket_mapping_contig->data_ptr<index_t>(),
+                  bucketized_permute.data_ptr<index_t>(),
+                  lengths_contig->numel());
+            }));
+      }));
+  return bucketized_permute;
+}
 
 std::tuple<
     Tensor,
@@ -2888,6 +2938,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
   m.def(
       "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, SymInt output_size) -> Tensor",
       {PT2_COMPLIANT_TAG});
+  m.def(
+      "populate_bucketized_permute(Tensor lengths, Tensor bucketized_lengths, Tensor bucket_mapping) -> Tensor");
   m.def(
       "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, SymInt my_size, Tensor? weights=None, Tensor? batch_size_per_feature=None, SymInt max_B= -1, Tensor[]? block_bucketize_pos=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
   m.def(
@@ -2988,6 +3040,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
   DISPATCH_TO_CPU("invert_permute", fbgemm_gpu::invert_permute_cpu);
   DISPATCH_TO_CPU(
       "expand_into_jagged_permute", fbgemm_gpu::expand_into_jagged_permute_cpu);
+  DISPATCH_TO_CPU(
+      "populate_bucketized_permute",
+      fbgemm_gpu::populate_bucketized_permute_cpu);
   DISPATCH_TO_CPU(
       "block_bucketize_sparse_features",
       fbgemm_gpu::block_bucketize_sparse_features_cpu);
diff --git a/fbgemm_gpu/test/sparse/block_bucketize_test.py b/fbgemm_gpu/test/sparse/block_bucketize_test.py
index d242564c41..cfac0f15a8 100644
--- a/fbgemm_gpu/test/sparse/block_bucketize_test.py
+++ b/fbgemm_gpu/test/sparse/block_bucketize_test.py
@@ -357,7 +357,6 @@ def test_block_bucketize_sparse_features_inference(
         self,
         index_type: Type[torch.dtype],
     ) -> None:
-        B = 2
         # pyre-ignore [6]
         lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type)
         indices = torch.tensor(
@@ -429,6 +428,110 @@ def test_block_bucketize_sparse_features_inference(
             bucket_mapping,
         )
 
+    @skipIfRocm(ROCM_FAILURE_MESSAGE)
+    @given(
+        index_type=st.sampled_from([torch.int, torch.long]),
+    )
+    @settings(verbosity=Verbosity.verbose, max_examples=16, deadline=None)
+    def test_populate_bucketized_permute(
+        self,
+        index_type: Type[torch.dtype],
+    ) -> None:
+        # pyre-ignore [6]
+        lengths = torch.tensor([0, 2, 1, 3, 2, 3, 3, 1], dtype=index_type)
+        indices = torch.tensor(
+            [3, 4, 15, 11, 28, 29, 1, 10, 11, 12, 13, 11, 22, 20, 20],
+            # pyre-ignore [6]
+            dtype=index_type,
+        )
+        # pyre-ignore [6]
+        block_sizes = torch.tensor([5, 15, 10, 20], dtype=index_type)
+        my_size = 2
+
+        new_lengths_ref = torch.tensor(
+            [0, 2, 0, 1, 1, 0, 1, 0, 0, 0, 1, 2, 1, 3, 2, 1],
+            # pyre-ignore [6]
+            dtype=index_type,
+        )
+        new_indices_ref = torch.tensor(
+            [3, 4, 11, 1, 11, 0, 13, 14, 0, 1, 2, 3, 2, 0, 0],
+            # pyre-ignore [6]
+            dtype=index_type,
+        )
+        (
+            new_lengths_cpu,
+            new_indices_cpu,
+            _,
+            _,
+            unbucketize_permute_cpu,
+            bucket_mapping_cpu,
+        ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
+            lengths,
+            indices,
+            False,
+            True,
+            block_sizes,
+            my_size,
+            None,
+            return_bucket_mapping=True,
+        )
+
+        unbucketize_permute_populated_cpu = (
+            torch.ops.fbgemm.populate_bucketized_permute(
+                lengths,
+                new_lengths_cpu,
+                bucket_mapping_cpu,
+            )
+        )
+        torch.testing.assert_close(
+            unbucketize_permute_populated_cpu, unbucketize_permute_cpu, rtol=0, atol=0
+        )
+        torch.testing.assert_close(new_lengths_cpu, new_lengths_ref, rtol=0, atol=0)
+        torch.testing.assert_close(new_indices_cpu, new_indices_ref, rtol=0, atol=0)
+
+        if gpu_available:
+            (
+                new_lengths_gpu,
+                new_indices_gpu,
+                _,
+                _,
+                unbucketize_permute_gpu,
+                bucket_mapping_gpu,
+            ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
+                lengths.cuda(),
+                indices.cuda(),
+                False,
+                True,
+                block_sizes.cuda(),
+                my_size,
+                None,
+                return_bucket_mapping=True,
+            )
+
+            unbucketize_permute_populated_gpu = (
+                torch.ops.fbgemm.populate_bucketized_permute(
+                    lengths.cuda(),
+                    new_lengths_gpu,
+                    bucket_mapping_gpu,
+                )
+            )
+            torch.testing.assert_close(
+                unbucketize_permute_gpu.cpu(),
+                unbucketize_permute_populated_gpu.cpu(),
+                rtol=0,
+                atol=0,
+            )
+            torch.testing.assert_close(
+                new_lengths_gpu.cpu(), new_lengths_ref, rtol=0, atol=0
+            )
+            torch.testing.assert_close(
+                new_indices_gpu.cpu(), new_indices_ref, rtol=0, atol=0
+            )
+            torch.testing.assert_close(
+                bucket_mapping_gpu.cpu(),
+                bucket_mapping_cpu,
+                rtol=0,
+                atol=0,
+            )
+
     @skipIfRocm(ROCM_FAILURE_MESSAGE)
     @given(
         index_type=st.sampled_from([torch.int, torch.long]),
diff --git a/fbgemm_gpu/test/sparse/failures_dict.json b/fbgemm_gpu/test/sparse/failures_dict.json
index a301178d0d..fefa85fa01 100644
--- a/fbgemm_gpu/test/sparse/failures_dict.json
+++ b/fbgemm_gpu/test/sparse/failures_dict.json
@@ -66,9 +66,17 @@
       "comment": "",
       "status": "xfail"
     },
+    "BlockBucketizeTest.test_aot_dispatch_dynamic__test_populate_bucketized_permute": {
+      "comment": "",
+      "status": "xfail"
+    },
     "BlockBucketizeTest.test_faketensor__test_block_bucketize_sparse_features_inference": {
       "comment": "",
       "status": "xfail"
+    },
+    "BlockBucketizeTest.test_faketensor__test_populate_bucketized_permute": {
+      "comment": "",
+      "status": "xfail"
     }
   },
   "fbgemm::bottom_k_per_row": {
@@ -182,6 +190,16 @@
       "status": "xfail"
     }
   },
+  "fbgemm::populate_bucketized_permute": {
+    "BlockBucketizeTest.test_aot_dispatch_dynamic__test_populate_bucketized_permute": {
+      "comment": "",
+      "status": "xfail"
+    },
+    "BlockBucketizeTest.test_faketensor__test_populate_bucketized_permute": {
+      "comment": "",
+      "status": "xfail"
+    }
+  },
   "fbgemm::reorder_batched_ad_indices": {
     "ReorderBatchedTest.test_aot_dispatch_dynamic__test_reorder_batched_ad_indices": {
       "comment": "",