expand permute for 1D sparse data (pytorch#975)
Summary: Pull Request resolved: pytorch#975

Reviewed By: jasonjk-park

Differential Revision: D34778095

fbshipit-source-id: 5c177cc7619e59046c645d59b36c7e52701d58e2
YazhiGao authored and facebook-github-bot committed Mar 28, 2022
1 parent dbe0a2a commit 073ea44
Showing 6 changed files with 298 additions and 1 deletion.
72 changes: 72 additions & 0 deletions fbgemm_gpu/bench/sparse_ops_benchmark.py
@@ -0,0 +1,72 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import random

import click
import fbgemm_gpu
import torch

logging.basicConfig(level=logging.DEBUG)

open_source: bool = getattr(fbgemm_gpu, "open_source", False)

if open_source:
    # pyre-ignore[21]
    from bench_utils import benchmark_torch_function
else:
    from fbgemm_gpu.bench.bench_utils import benchmark_torch_function

torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")


@click.group()
def cli() -> None:
    pass


@cli.command()
@click.option("--world-size", default=128)
@click.option("--num-tables", default=10)
@click.option("--min-len", default=10000)
@click.option("--max-len", default=20000)
def device(
    world_size: int,
    num_tables: int,
    min_len: int,
    max_len: int,
) -> None:
    lengths = torch.randint(min_len, max_len, size=(num_tables * world_size,))
    offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(lengths)
    permute = list(range(num_tables * world_size))
    random.shuffle(permute)
    permute_tensor = torch.tensor(permute)
    permuted_length = torch.index_select(lengths, 0, permute_tensor)
    permuted_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(permuted_length)
    jagged_size = offsets[-1]

    if torch.cuda.is_available():
        permute_tensor = permute_tensor.cuda()
        offsets = offsets.cuda()
        permuted_offsets = permuted_offsets.cuda()

    time, output = benchmark_torch_function(
        torch.ops.fbgemm.expand_into_jagged_permute,
        (permute_tensor, offsets, permuted_offsets, jagged_size),
    )

    num_bytes = (
        permute_tensor.numel() * permute_tensor.element_size()
        + offsets.numel() * offsets.element_size()
        + permuted_offsets.numel() * permuted_offsets.element_size()
        + output.numel() * output.element_size()
    )
    logging.info(f"expand_into_jagged_permute {time} sec {num_bytes / time / 1e9} GB/s")


if __name__ == "__main__":
    cli()
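Assuming the open-source layout, the new benchmark can be invoked roughly as `python fbgemm_gpu/bench/sparse_ops_benchmark.py device --world-size 128 --num-tables 10 --min-len 10000 --max-len 20000` (all flags are optional; the values shown are the defaults above). The reported GB/s figure divides the byte count of the four tensors the op reads and writes, per the num_bytes computation, by the measured time.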
32 changes: 32 additions & 0 deletions fbgemm_gpu/include/fbgemm_gpu/sparse_ops.h
@@ -54,6 +54,38 @@ permute_1D_sparse_data_cuda(
    const c10::optional<at::Tensor>& weights,
    const c10::optional<int64_t>& permuted_lengths_sum);

/*
 * expand_into_jagged_permute expands the sparse data permute index from the
 * table dimension to the batch dimension, for cases where the sparse features
 * have different batch sizes across ranks.
 *
 * "permute":
 *    the table-level permute index.
 * "input_offsets":
 *    the exclusive offsets of the table-level lengths.
 * "output_offsets":
 *    the exclusive offsets of the table-level permuted lengths.
 *
 * The op expands the permute from the table level to the batch level by
 * contiguously mapping each bag of a table to the position its batch occupies
 * after the feature permute. Offset arrays over tables and batches are derived
 * to compute the output permute.
 *
 * The output follows this formula:
 *    output_permute[table_offset[permute[table]] + batch] <- bag_offset[batch]
 */
at::Tensor expand_into_jagged_permute_cuda(
    const at::Tensor& permute,
    const at::Tensor& input_offsets,
    const at::Tensor& output_offsets,
    int64_t output_size);

at::Tensor expand_into_jagged_permute_cpu(
    const at::Tensor& permute,
    const at::Tensor& input_offsets,
    const at::Tensor& output_offsets,
    int64_t output_size);
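To make the formula above concrete, here is a minimal pure-Python sketch of the op's semantics (illustrative only, not part of this commit): two tables with bag lengths [2, 3] are permuted as [1, 0], so the bags of table 1 are mapped to the front of the expanded permute.

```python
# Illustrative sketch of expand_into_jagged_permute semantics (hypothetical
# standalone example; the real op runs on tensors via torch.ops.fbgemm).
permute = [1, 0]             # table-level permute: table 1 first, then table 0
input_offsets = [0, 2, 5]    # exclusive cumsum of the original lengths [2, 3]
output_offsets = [0, 3, 5]   # exclusive cumsum of the permuted lengths [3, 2]

output_permute = [0] * input_offsets[-1]
for t, src in enumerate(permute):
    input_start = input_offsets[src]
    segment_length = output_offsets[t + 1] - output_offsets[t]
    for i in range(segment_length):
        output_permute[output_offsets[t] + i] = input_start + i

print(output_permute)  # [2, 3, 4, 0, 1]: table 1's bags, then table 0's
```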

std::tuple<
    at::Tensor,
    at::Tensor,
64 changes: 64 additions & 0 deletions fbgemm_gpu/src/sparse_ops.cu
@@ -664,6 +664,70 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cuda(
  return {permuted_lengths, permuted_indices, permuted_weights};
}
// Kernel that generates the 1D data permute from the table-dimension permute
// index. Used for the permutation of sparse features.
template <typename index_t, typename offsets_t>
__global__ void expand_into_jagged_permute_kernel(
    const offsets_t* __restrict__ input_offsets,
    const offsets_t* __restrict__ output_offsets,
    int32_t input_size,
    const index_t* __restrict__ permute,
    index_t* __restrict__ output_permute) {
  const int32_t t_start = blockIdx.x * blockDim.y + threadIdx.y;
  const int stride = gridDim.x * blockDim.y;
  for (int t = t_start; t < input_size; t += stride) {
    const offsets_t output_start = output_offsets[t];
    const offsets_t segment_length = output_offsets[t + 1] - output_offsets[t];
    const offsets_t input_start = input_offsets[permute[t]];
    for (int32_t i = threadIdx.x; i < segment_length; i += blockDim.x) {
      output_permute[output_start + i] = input_start + i;
    }
  }
}
Tensor expand_into_jagged_permute_cuda(
    const Tensor& permute,
    const Tensor& input_offsets,
    const Tensor& output_offsets,
    int64_t output_size) {
  TENSOR_ON_CUDA_GPU(permute);
  TENSOR_ON_CUDA_GPU(input_offsets);
  TENSOR_ON_CUDA_GPU(output_offsets);
  TENSORS_ON_SAME_DEVICE(permute, input_offsets);
  TENSORS_ON_SAME_DEVICE(permute, output_offsets);
  TORCH_CHECK(permute.numel() > 0);
  TORCH_CHECK(permute.numel() == input_offsets.numel() - 1);
  TORCH_CHECK(permute.numel() == output_offsets.numel() - 1);

  at::cuda::OptionalCUDAGuard device_guard;
  device_guard.set_index(permute.get_device());

  const auto permute_contig = permute.contiguous();
  const auto permute_size = permute.numel();

  Tensor output_permute = at::empty({output_size}, permute.options());

  // number of tables per block
  constexpr int32_t T_blocks = kMaxThreads / kWarpSize;
  dim3 threads(kWarpSize, T_blocks);
  const auto blocks = cuda_calc_xblock_count(permute_size, T_blocks);
  AT_DISPATCH_INDEX_TYPES(
      permute.scalar_type(), "expand_into_jagged_permute_kernel", [&] {
        using offsets_t = index_t;
        expand_into_jagged_permute_kernel<index_t, offsets_t>
            <<<blocks, threads, 0, at::cuda::getCurrentCUDAStream()>>>(
                input_offsets.data_ptr<offsets_t>(),
                output_offsets.data_ptr<offsets_t>(),
                permute_size,
                permute_contig.data_ptr<index_t>(),
                output_permute.data_ptr<index_t>());
        C10_CUDA_KERNEL_LAUNCH_CHECK();
      });
  return output_permute;
}
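As a reading of the launch configuration above: each table's segment is copied by one warp (kWarpSize lanes on threadIdx.x), kMaxThreads / kWarpSize tables are stacked per block on threadIdx.y, and the outer grid-stride loop lets a bounded grid cover any number of tables. Consecutive lanes write consecutive output_permute entries, so the stores stay coalesced.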
// Kernel for bucketizing lengths, with the Block distribution (vs. cyclic or
// block-cyclic distribution). Used for bucketizing sparse features, especially
// for checkpointing with row-wise partitioning (sparse_feature is partitioned
57 changes: 56 additions & 1 deletion fbgemm_gpu/src/sparse_ops_cpu.cpp
@@ -662,6 +662,58 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_1D_sparse_data_cpu(
  return {permuted_lengths, permuted_indices, permuted_weights};
}

template <typename index_t, typename offsets_t>
void _expand_into_jagged_permute_cpu_kernel(
    const offsets_t* const __restrict__ input_offsets,
    const offsets_t* const __restrict__ output_offsets,
    const int64_t permute_size,
    const index_t* const __restrict__ permute,
    index_t* const __restrict__ output_permute) {
  at::parallel_for(
      0, permute_size, FALSE_SHARING_PAD, [&](int64_t t_begin, int64_t t_end) {
        for (int t = t_begin; t < std::min(t_end, permute_size); ++t) {
          offsets_t permute_length = output_offsets[t + 1] - output_offsets[t];
          const offsets_t input_start = input_offsets[permute[t]];
          const offsets_t output_start = output_offsets[t];
          for (const auto i : c10::irange(permute_length)) {
            output_permute[output_start + i] = input_start + i;
          }
        }
      }); // parallel_for T
}

Tensor expand_into_jagged_permute_cpu(
    const Tensor& permute,
    const Tensor& input_offsets,
    const Tensor& output_offsets,
    int64_t output_size) {
  TENSOR_ON_CPU(permute);
  TENSOR_ON_CPU(input_offsets);
  TENSOR_ON_CPU(output_offsets);
  TORCH_CHECK(permute.numel() > 0);
  TORCH_CHECK(permute.numel() == input_offsets.numel() - 1);
  TORCH_CHECK(permute.numel() == output_offsets.numel() - 1);

  const auto permute_contig = permute.contiguous();

  const auto permute_size = permute.numel();

  Tensor output_permute = at::empty({output_size}, input_offsets.options());

  AT_DISPATCH_INDEX_TYPES(
      permute.scalar_type(), "expand_into_jagged_permute_cpu", [&] {
        using offset_t = index_t;
        _expand_into_jagged_permute_cpu_kernel(
            input_offsets.data_ptr<offset_t>(),
            output_offsets.data_ptr<offset_t>(),
            permute_size,
            permute_contig.data_ptr<index_t>(),
            output_permute.data_ptr<index_t>());
      });

  return output_permute;
}
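The CPU path parallelizes over tables with at::parallel_for, using FALSE_SHARING_PAD as the grain size, presumably so that each worker handles a padded chunk of tables and adjacent threads do not contend on the same cache lines of output_permute; the inner loop is the same contiguous segment copy as in the CUDA kernel.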

std::tuple<
    Tensor,
    Tensor,
@@ -1700,6 +1752,8 @@ TORCH_LIBRARY_FRAGMENT(fbgemm, m) {
      "permute_2D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
  m.def(
      "permute_1D_sparse_data(Tensor permute, Tensor lengths, Tensor values, Tensor? weights=None, int? permuted_lengths_sum=None) -> (Tensor, Tensor, Tensor?)");
  m.def(
      "expand_into_jagged_permute(Tensor permute, Tensor input_offset, Tensor output_offset, int output_size) -> Tensor");
  m.def(
      "block_bucketize_sparse_features(Tensor lengths, Tensor indices, bool bucketize_pos, bool sequence, Tensor block_sizes, int my_size, Tensor? weights=None) -> (Tensor, Tensor, Tensor?, Tensor?, Tensor?)");
  m.def(
@@ -1736,7 +1790,8 @@ TORCH_LIBRARY_IMPL(fbgemm, CPU, m) {
      "permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cpu);
  DISPATCH_TO_CPU(
      "permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cpu);
  DISPATCH_TO_CPU(
      "expand_into_jagged_permute", fbgemm_gpu::expand_into_jagged_permute_cpu);
  DISPATCH_TO_CPU(
      "block_bucketize_sparse_features",
      fbgemm_gpu::block_bucketize_sparse_features_cpu);
3 changes: 3 additions & 0 deletions fbgemm_gpu/src/sparse_ops_gpu.cpp
@@ -126,6 +126,9 @@ TORCH_LIBRARY_IMPL(fbgemm, CUDA, m) {
      "permute_2D_sparse_data", fbgemm_gpu::permute_2D_sparse_data_cuda);
  DISPATCH_TO_CUDA(
      "permute_1D_sparse_data", fbgemm_gpu::permute_1D_sparse_data_cuda);
  DISPATCH_TO_CUDA(
      "expand_into_jagged_permute",
      fbgemm_gpu::expand_into_jagged_permute_cuda);
  DISPATCH_TO_CUDA(
      "block_bucketize_sparse_features",
      fbgemm_gpu::block_bucketize_sparse_features_cuda);
71 changes: 71 additions & 0 deletions fbgemm_gpu/test/sparse_ops_test.py
@@ -134,6 +134,77 @@ def permute_indices_ref_(

        return permuted_lengths, permuted_indices, permuted_weights

    @staticmethod
    def expand_into_jagged_permute_ref_(
        permute: List[int],
        length: List[int],
    ) -> List[int]:
        offsets = [0] + list(itertools.accumulate(length))
        output_permute = []
        for r in permute:
            output_permute.extend(
                range(
                    offsets[r],
                    offsets[r + 1],
                )
            )
        return output_permute

    # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
    @given(
        T=st.integers(min_value=10, max_value=20),
        W=st.integers(min_value=8, max_value=128),
    )
    @settings(verbosity=Verbosity.verbose, max_examples=10, deadline=None)
    def test_expand_into_jagged_permute(
        self,
        T: int,
        W: int,
    ) -> None:
        length_per_w = [random.randint(10000, 20000) for _ in range(W)]
        length_1d = list(
            itertools.chain.from_iterable(itertools.repeat(x, T) for x in length_per_w)
        )
        permute_list = list(range(T * W))
        random.shuffle(permute_list)
        permuted_length_1d = [length_1d[r] for r in permute_list]
        permute_tensor = torch.tensor(permute_list)

        # compute offsets
        offsets_1d = [0] + list(itertools.accumulate(length_1d))
        permuted_offsets_1d = [0] + list(itertools.accumulate(permuted_length_1d))
        offsets_1d_tensor = torch.tensor(offsets_1d)
        permuted_offsets_1d_tensor = torch.tensor(permuted_offsets_1d)

        # cpu op
        output_permute_cpu = torch.ops.fbgemm.expand_into_jagged_permute(
            permute_tensor,
            offsets_1d_tensor,
            permuted_offsets_1d_tensor,
            offsets_1d[-1],
        )

        # reference solution
        output_permute_ref = self.expand_into_jagged_permute_ref_(
            permute_list,
            length_1d,
        )
        output_permute_ref_tensor = torch.tensor(output_permute_ref)

        # check the cpu and gpu ops against the reference
        torch.testing.assert_allclose(output_permute_cpu, output_permute_ref_tensor)
        if gpu_available:
            # gpu op
            output_permute_gpu = torch.ops.fbgemm.expand_into_jagged_permute(
                permute_tensor.cuda(),
                offsets_1d_tensor.cuda(),
                permuted_offsets_1d_tensor.cuda(),
                offsets_1d[-1],
            )
            torch.testing.assert_allclose(
                output_permute_gpu.cpu(), output_permute_ref_tensor
            )

    # pyre-ignore [56]: Invalid decoration, was not able to infer the type of argument
    @given(
        B=st.integers(min_value=1, max_value=20),