async _to_offsets(...) for KJT
Summary:
1. Open source cumsum kernels
2. Use asynchronous_complete_cumsum to implement _to_offsets(...)

Reviewed By: dstaay-fb

Differential Revision: D31298164

fbshipit-source-id: c11d83766055e8b75d051ea6360433ea06697377
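
Point 2 of the summary lands outside this diff (in the KeyedJaggedTensor code). A rough sketch of how a _to_offsets(...) helper can wrap the new op; the helper name and signature are illustrative, only the fbgemm op call comes from this commit:

    import torch

    def _to_offsets(lengths: torch.Tensor) -> torch.Tensor:
        # Complete cumsum returns numel() + 1 entries: a leading zero followed by
        # the running totals, which is exactly the offsets layout of a jagged tensor.
        # e.g. lengths [2, 0, 3] -> offsets [0, 2, 2, 5]
        return torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)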
xing-liu authored and facebook-github-bot committed Oct 5, 2021
1 parent 96e7cd8 commit 5243fc4
Showing 5 changed files with 187 additions and 10 deletions.
12 changes: 11 additions & 1 deletion fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -12,7 +12,17 @@
namespace at {

// Return array of size T_in.numel(), representing incomplete exclusive cumsum
Tensor asynchronous_exclusive_cumsum(const Tensor& t_in);
at::Tensor asynchronous_exclusive_cumsum_gpu(const at::Tensor& t_in);

at::Tensor asynchronous_complete_cumsum_gpu(const at::Tensor& t_in);

at::Tensor asynchronous_inclusive_cumsum_gpu(const at::Tensor& t_in);

at::Tensor asynchronous_exclusive_cumsum_cpu(const at::Tensor& t_in);

at::Tensor asynchronous_complete_cumsum_cpu(const at::Tensor& t_in);

at::Tensor asynchronous_inclusive_cumsum_cpu(const at::Tensor& t_in);

std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
const Tensor& permute,
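The three variants declared above differ only in whether the leading zero and the final total are included. A small illustration of the expected outputs, assuming the fbgemm ops have been loaded and given a 1-D integer tensor:

    import torch

    x = torch.tensor([3, 1, 4], dtype=torch.int32)
    torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)  # [3, 4, 8]     -> numel() elements
    torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)  # [0, 3, 4]     -> numel() elements
    torch.ops.fbgemm.asynchronous_complete_cumsum(x)   # [0, 3, 4, 8]  -> numel() + 1 elements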
48 changes: 42 additions & 6 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -21,7 +21,7 @@
#include "cub/device/device_scan.cuh"

namespace at {
Tensor asynchronous_inclusive_cumsum(const Tensor& t_in) {
Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
@@ -55,7 +55,7 @@ Tensor asynchronous_inclusive_cumsum(const Tensor& t_in) {
return t_out;
}

Tensor asynchronous_exclusive_cumsum(const Tensor& t_in) {
Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
@@ -89,6 +89,42 @@ Tensor asynchronous_exclusive_cumsum(const Tensor& t_in) {
return t_out;
}

Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
at::cuda::OptionalCUDAGuard device_guard;
device_guard.set_index(t_in.get_device());
size_t temp_storage_bytes = 0;
TORCH_CHECK(t_in.is_contiguous());
TORCH_CHECK(t_in.dtype() == kInt || t_in.dtype() == kLong);
// CUB only handles up to INT_MAX elements.
TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
TORCH_CHECK(t_in.dim() == 1);
auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
t_out[0].zero_();
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper1", ([&] {
AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
nullptr,
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
}));
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)}, t_in.options().dtype(kByte));
AT_DISPATCH_INTEGRAL_TYPES(
t_in.scalar_type(), "cub_inclusive_sum_wrapper2", ([&] {
AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
temp_storage.data_ptr(),
temp_storage_bytes,
t_in.data_ptr<scalar_t>(),
t_out.data_ptr<scalar_t>() + 1,
t_in.numel(),
at::cuda::getCurrentCUDAStream()));
}));
return t_out;
}

std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
const Tensor& permute,
const Tensor& lengths,
@@ -137,8 +173,8 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
}));

// convert lengths to offsets
const auto input_offsets = asynchronous_exclusive_cumsum(lengths_contig);
const auto output_offsets = asynchronous_exclusive_cumsum(permuted_lengths);
const auto input_offsets = asynchronous_exclusive_cumsum_gpu(lengths_contig);
const auto output_offsets = asynchronous_exclusive_cumsum_gpu(permuted_lengths);
int64_t permuted_indices_size = 0;
if (permuted_lengths_sum.has_value()) {
permuted_indices_size = permuted_lengths_sum.value();
@@ -245,7 +281,7 @@ block_bucketize_sparse_features_cuda(
Tensor new_pos;
Tensor unbucketize_permute;
// count nonzeros
offsets_contig = asynchronous_inclusive_cumsum(lengths);
offsets_contig = asynchronous_inclusive_cumsum_gpu(lengths);
int threads_per_block = 256;
int num_blocks = (lengths_size + threads_per_block - 1) / threads_per_block;
AT_DISPATCH_INDEX_TYPES(
@@ -274,7 +310,7 @@ block_bucketize_sparse_features_cuda(
}));
// bucketize nonzeros
new_offsets = asynchronous_exclusive_cumsum(new_lengths);
new_offsets = asynchronous_exclusive_cumsum_gpu(new_lengths);
if (sequence) {
const auto lengths_sum = indices.numel();
unbucketize_permute = at::empty({lengths_sum}, indices.options());
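The GPU complete cumsum above avoids a dedicated kernel: it allocates numel() + 1 outputs, zeroes the first element, and runs CUB's InclusiveSum into the remaining slots. A pure-PyTorch sketch of the same layout trick, for illustration only (not the shipped CUDA path):

    import torch

    def complete_cumsum_reference(x: torch.Tensor) -> torch.Tensor:
        # out[0] = 0, then an inclusive scan shifted by one slot fills out[1:].
        out = torch.empty(x.numel() + 1, dtype=x.dtype, device=x.device)
        out[0] = 0
        out[1:] = torch.cumsum(x, dim=0)
        return out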
85 changes: 83 additions & 2 deletions fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -159,7 +159,12 @@ void _permute_lengths_cpu_kernel(
}
}

template <bool sequence, bool has_weight, typename offset_t, typename index_t, typename scalar_t>
template <
bool sequence,
bool has_weight,
typename offset_t,
typename index_t,
typename scalar_t>
void _block_bucketize_sparse_features_cpu(
at::Tensor lengths,
at::Tensor indices,
@@ -533,16 +538,92 @@ block_bucketize_sparse_features_cpu(
return {new_lengths, new_indices, new_weights, new_pos, unbucketize_permute};
}

// 1D exclusive scan: output[i] = input[0] + input[1] + ... + input[i-1]
// Used as a helper to several functions below.
template <class T, class U>
U exclusive_scan_ptrs_cpu(
const int64_t N,
const T* const input,
U* const output) {
U cumsum = 0;
for (const auto i : c10::irange(N)) {
output[i] = cumsum;
cumsum += input[i];
}
return cumsum;
}

at::Tensor asynchronous_exclusive_cumsum_cpu(const at::Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
AT_DISPATCH_ALL_TYPES(
t_in_contig->type(), "asynchronous_exclusive_cumsum_cpu_kernel", ([&] {
exclusive_scan_ptrs_cpu(
t_in_contig->numel(),
t_in_contig->data_ptr<scalar_t>(),
output.data_ptr<scalar_t>());
}));
return output;
}

at::Tensor asynchronous_inclusive_cumsum_cpu(const at::Tensor& t_in) {
TENSOR_ON_CPU(t_in);

const auto t_in_contig = t_in.expect_contiguous();
auto output = native_empty_like(*t_in_contig);
AT_DISPATCH_ALL_TYPES(
t_in_contig->type(), "asynchronous_inclusive_cumsum_cpu_kernel", ([&] {
scalar_t cumsum = 0;
const auto* input_ptr = t_in_contig->data_ptr<scalar_t>();
const auto N = t_in_contig->numel();
auto* output_ptr = output.data_ptr<scalar_t>();

for (const auto i : c10::irange(N)) {
cumsum += input_ptr[i];
output_ptr[i] = cumsum;
}
}));
return output;
}

at::Tensor asynchronous_complete_cumsum_cpu(const at::Tensor& t_in) {
TENSOR_ON_CPU(t_in);
TORCH_CHECK(t_in.dim() == 1);

const auto t_in_contig = t_in.expect_contiguous();
auto output = at::zeros({t_in.numel() + 1}, t_in.options());
AT_DISPATCH_ALL_TYPES(
t_in_contig->type(), "asynchronous_complete_cumsum_cpu_kernel", ([&] {
const auto N = t_in_contig->numel();
const auto last_sum = exclusive_scan_ptrs_cpu(
N, t_in_contig->data_ptr<scalar_t>(), output.data_ptr<scalar_t>());
output.data_ptr<scalar_t>()[N] = last_sum;
}));
return output;
}

} // namespace at

TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
m.def(
"permute_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
m.def(
"block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, int my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
m.def("asynchronous_exclusive_cumsum(Tensor t_in) -> Tensor");
m.def("asynchronous_inclusive_cumsum(Tensor t_in) -> Tensor");
m.def("asynchronous_complete_cumsum(Tensor t_in) -> Tensor");
}

TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
m.impl("permute_sparse_data", at::permute_sparse_data_cpu);
m.impl("block_bucketize_sparse_features", at::block_bucketize_sparse_features_cpu);
m.impl(
"block_bucketize_sparse_features",
at::block_bucketize_sparse_features_cpu);
m.impl(
"asynchronous_exclusive_cumsum", at::asynchronous_exclusive_cumsum_cpu);
m.impl(
"asynchronous_inclusive_cumsum", at::asynchronous_inclusive_cumsum_cpu);
m.impl("asynchronous_complete_cumsum", at::asynchronous_complete_cumsum_cpu);
}
10 changes: 9 additions & 1 deletion fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -14,5 +14,13 @@

TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
DISPATCH_TO_CUDA("permute_sparse_data", at::permute_sparse_data_cuda);
DISPATCH_TO_CUDA("block_bucketize_sparse_features", at::block_bucketize_sparse_features_cuda);
DISPATCH_TO_CUDA(
"block_bucketize_sparse_features",
at::block_bucketize_sparse_features_cuda);
DISPATCH_TO_CUDA(
"asynchronous_exclusive_cumsum", at::asynchronous_exclusive_cumsum_gpu);
DISPATCH_TO_CUDA(
"asynchronous_complete_cumsum", at::asynchronous_complete_cumsum_gpu);
DISPATCH_TO_CUDA(
"asynchronous_inclusive_cumsum", at::asynchronous_inclusive_cumsum_gpu);
}
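
With the schema declared once in TORCH_LIBRARY_FRAGMENT and the CPU and CUDA implementations bound separately, a single Python call dispatches on the input tensor's device. A minimal usage sketch, assuming the fbgemm_gpu extension has already been loaded so the ops are registered:

    import torch

    lengths = torch.tensor([2, 0, 3], dtype=torch.int32)
    cpu_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)  # CPU kernel
    if torch.cuda.is_available():
        gpu_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths.cuda())  # CUDA kernel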
42 changes: 42 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -415,6 +415,48 @@ def test_block_bucketize_sparse_features_long_indices(
torch.testing.assert_allclose(new_lengths_gpu.cpu(), new_lengths_ref)
torch.testing.assert_allclose(new_indices_gpu.cpu(), new_indices_ref)

# pyre-ignore [56]
@given(
n=st.integers(min_value=1, max_value=100),
long_index=st.booleans(),
)
@settings(verbosity=Verbosity.verbose, max_examples=20, deadline=None)
def test_cumsum(self, n: int, long_index: bool) -> None:
index_dtype = torch.int64 if long_index else torch.int32
np_index_dtype = np.int64 if long_index else np.int32

# cpu tests
x = torch.randint(low=0, high=100, size=(n,)).type(index_dtype)
ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
torch.testing.assert_allclose(
np.cumsum(x.cpu().numpy()).astype(np_index_dtype), zi.cpu()
)
torch.testing.assert_allclose(
(np.cumsum([0] + x.cpu().numpy().tolist())[:-1]).astype(np_index_dtype),
ze.cpu(),
)
torch.testing.assert_allclose(
(np.cumsum([0] + x.cpu().numpy().tolist())).astype(np_index_dtype), zc.cpu()
)

if torch.cuda.is_available():
x = x.cuda()
ze = torch.ops.fbgemm.asynchronous_exclusive_cumsum(x)
zi = torch.ops.fbgemm.asynchronous_inclusive_cumsum(x)
zc = torch.ops.fbgemm.asynchronous_complete_cumsum(x)
torch.testing.assert_allclose(
np.cumsum(x.cpu().numpy()).astype(np_index_dtype), zi.cpu()
)
torch.testing.assert_allclose(
(np.cumsum([0] + x.cpu().numpy().tolist())[:-1]).astype(np_index_dtype),
ze.cpu(),
)
torch.testing.assert_allclose(
(np.cumsum([0] + x.cpu().numpy().tolist())).astype(np_index_dtype), zc.cpu()
)

@unittest.skipIf(not torch.cuda.is_available(), "Skip when CUDA is not available")
# pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
@given(
