Add support for int32_t indices in TBE training (2B/N) (pytorch#3371)
Summary:

X-link: facebookresearch/FBGEMM#462

- Fix `hash_size_cumsum` to be `int64_t` in `linearize_index_index_select_kernel` and `linearize_index_kernel` (a sketch of the overflow this guards against follows below)
- Enable more specializations for radix sort
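For context, a minimal host-side sketch of the overflow the `int64_t` fix guards against. The row counts below are made up for illustration; the point is that even when every per-table index fits in int32_t, the cumulative offsets across tables can exceed INT32_MAX, so the cumsum must be carried in 64 bits.

#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // Hypothetical row counts: each fits comfortably in int32_t...
  const std::vector<int64_t> rows_per_table = {
      1'500'000'000, 1'500'000'000, 1'500'000'000};

  // ...but their running sum does not.
  std::vector<int64_t> hash_size_cumsum = {0};
  for (const auto r : rows_per_table) {
    hash_size_cumsum.push_back(hash_size_cumsum.back() + r);
  }

  // 4'500'000'000 > INT32_MAX (2'147'483'647): an int32_t cumsum would
  // wrap around, which is why the kernels now take an int64_t accessor.
  std::cout << hash_size_cumsum.back() << std::endl;
  return 0;
}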

Reviewed By: r-barnes

Differential Revision: D65923591
q10 authored and facebook-github-bot committed Dec 14, 2024
1 parent c932a35 commit d90e68b
Showing 3 changed files with 27 additions and 21 deletions.
8 changes: 6 additions & 2 deletions fbgemm_gpu/include/fbgemm_gpu/split_embeddings_utils.cuh
@@ -59,9 +59,13 @@ transpose_embedding_input(
       int end_bit = sizeof(KeyT) * 8, \
       cudaStream_t stream = 0)
 
-DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
-DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
 DECL_RADIX_SORT_PAIRS_FN(int64_t, float);
 DECL_RADIX_SORT_PAIRS_FN(int64_t, double);
+DECL_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
+DECL_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
+DECL_RADIX_SORT_PAIRS_FN(int32_t, int32_t);
+DECL_RADIX_SORT_PAIRS_FN(int32_t, int64_t);
+DECL_RADIX_SORT_PAIRS_FN(int32_t, float);
+DECL_RADIX_SORT_PAIRS_FN(int32_t, double);
 
 #undef DECL_RADIX_SORT_PAIRS_FN
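As a reading aid, here is a sketch of what one such declaration might expand to, inferred from the visible tail of the macro (the `end_bit` and `stream` defaults); the exact name and parameter list in the FBGEMM header may differ. KeyT = int32_t / ValueT = float is one of the newly added specializations.

// Hypothetical expansion of DECL_RADIX_SORT_PAIRS_FN(int32_t, float):
// a CUB-style, two-pass temp-storage API (names and parameters inferred,
// not copied from the FBGEMM header).
#include <cstdint>
#include <cuda_runtime.h>

cudaError_t radix_sort_pairs(
    void* d_temp_storage,        // nullptr => size query only
    size_t& temp_storage_bytes,  // in/out: required scratch size
    const int32_t* keys_in,      // KeyT = int32_t (newly specialized)
    int32_t* keys_out,
    const float* values_in,      // ValueT = float
    float* values_out,
    int num_items,
    int begin_bit = 0,
    int end_bit = sizeof(int32_t) * 8,
    cudaStream_t stream = 0);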
8 changes: 6 additions & 2 deletions fbgemm_gpu/src/split_embeddings_utils/radix_sort_pairs.cu
@@ -77,7 +77,11 @@ using namespace fbgemm_gpu;
 }
 #endif
 
-DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
-DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
 DEF_RADIX_SORT_PAIRS_FN(int64_t, float);
 DEF_RADIX_SORT_PAIRS_FN(int64_t, double);
+DEF_RADIX_SORT_PAIRS_FN(int64_t, int64_t);
+DEF_RADIX_SORT_PAIRS_FN(int64_t, int32_t);
+DEF_RADIX_SORT_PAIRS_FN(int32_t, int32_t);
+DEF_RADIX_SORT_PAIRS_FN(int32_t, int64_t);
+DEF_RADIX_SORT_PAIRS_FN(int32_t, float);
+DEF_RADIX_SORT_PAIRS_FN(int32_t, double);
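Assuming the hypothetical declaration sketched above, a call site would follow the usual CUB two-pass pattern: first query the scratch size with a null pointer, then run the sort. Illustrative only:

#include <cstdint>
#include <cuda_runtime.h>

// Assumes the hypothetical radix_sort_pairs declaration sketched earlier.
void sort_int32_keys(
    const int32_t* d_keys_in, int32_t* d_keys_out,
    const float* d_vals_in, float* d_vals_out, int num_items) {
  size_t temp_bytes = 0;
  // Pass 1: null scratch pointer => only compute required temp storage.
  radix_sort_pairs(nullptr, temp_bytes, d_keys_in, d_keys_out,
                   d_vals_in, d_vals_out, num_items);

  void* d_temp = nullptr;
  cudaMalloc(&d_temp, temp_bytes);

  // Pass 2: perform the actual key/value sort on the default stream.
  radix_sort_pairs(d_temp, temp_bytes, d_keys_in, d_keys_out,
                   d_vals_in, d_vals_out, num_items);
  cudaFree(d_temp);
}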
32 changes: 15 additions & 17 deletions fbgemm_gpu/src/split_embeddings_utils/transpose_embedding_input.cu
@@ -63,7 +63,7 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
 
 template <typename index_t, typename info_acc_t, bool nobag, bool vbe>
 __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_size_cumsum,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
@@ -79,7 +79,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
     // Use a raw pointer to avoid creating dummy PackedTensorAccessor
     const uint32_t* const __restrict__ vbe_b_t_map,
     FixedDivisor fd) {
-  const int32_t T = hash_size_cumsum.size(0) - 1;
+  const auto T = hash_size_cumsum.size(0) - 1;
   auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t b;
   int32_t t;
@@ -97,21 +97,20 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
   }
 
   const index_t hash_offset = valid ? hash_size_cumsum[t] : -1;
-  const index_t indices_start = valid ? offsets[b_t] : -1;
-  const int32_t L = valid ? offsets[b_t + 1] - indices_start : 0;
+  const auto indices_start = valid ? offsets[b_t] : -1;
+  const auto L = valid ? offsets[b_t + 1] - indices_start : 0;
   const int32_t lane_id = threadIdx.x % fbgemm_gpu::kWarpSize;
 
   // Compile-time conditional
   if (nobag) {
     for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-      const index_t indices_start_warp =
-          fbgemm_gpu::shfl_sync(indices_start, j);
-      const int32_t t_warp = fbgemm_gpu::shfl_sync(t, j);
-      const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+      const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+      const auto t_warp = fbgemm_gpu::shfl_sync(t, j);
+      const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
       const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
       for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
         const index_t idx = __ldg(&indices[indices_start_warp + i]);
-        const int64_t l_t = (indices_start_warp + i) * T + t_warp;
+        const auto l_t = (indices_start_warp + i) * T + t_warp;
         infos[indices_start_warp + i] = l_t;
         linear_indices[indices_start_warp + i] = hash_offset_warp + idx;
       }
@@ -124,10 +123,9 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
           reinterpret_cast<uint32_t*>(&b)[0];
     }
     for (int32_t j = 0; j < fbgemm_gpu::kWarpSize; ++j) {
-      const index_t indices_start_warp =
-          fbgemm_gpu::shfl_sync(indices_start, j);
-      const uint32_t info_warp = fbgemm_gpu::shfl_sync(info, j);
-      const int32_t L_warp = fbgemm_gpu::shfl_sync(L, j);
+      const auto indices_start_warp = fbgemm_gpu::shfl_sync(indices_start, j);
+      const auto info_warp = fbgemm_gpu::shfl_sync(info, j);
+      const auto L_warp = fbgemm_gpu::shfl_sync(L, j);
       const index_t hash_offset_warp = fbgemm_gpu::shfl_sync(hash_offset, j);
       for (int32_t i = lane_id; i < L_warp; i += fbgemm_gpu::kWarpSize) {
         const index_t idx = __ldg(&indices[indices_start_warp + i]);
@@ -142,7 +140,7 @@ __global__ __launch_bounds__(kMaxThreads) void linearize_index_kernel(
 template <typename index_t, typename info_acc_t>
 __global__
 __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
-    const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
+    const pta::PackedTensorAccessor32<int64_t, 1, at::RestrictPtrTraits>
         hash_size_cumsum,
     const pta::PackedTensorAccessor32<index_t, 1, at::RestrictPtrTraits>
         indices,
@@ -153,7 +151,7 @@ __launch_bounds__(kMaxThreads) void linearize_index_index_select_kernel(
         linear_indices,
     FixedDivisor fd,
     int32_t fixed_L_per_warp) {
-  const int32_t T = hash_size_cumsum.size(0) - 1;
+  const auto T = hash_size_cumsum.size(0) - 1;
   auto b_t = blockIdx.x * blockDim.x + threadIdx.x;
   int32_t b;
   int32_t t;
@@ -258,7 +256,7 @@ transpose_embedding_input(
           kMaxThreads, \
           0, \
           at::cuda::getCurrentCUDAStream()>>>( \
-          MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, index_t, 1, 32), \
+          MAKE_PTA_WITH_NAME(func_name, hash_size_cumsum, int64_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, offsets, index_t, 1, 32), \
          MAKE_PTA_WITH_NAME(func_name, infos, INFO_ACC_T, 1, 32), \
@@ -296,7 +294,7 @@ transpose_embedding_input(
             0,
             at::cuda::getCurrentCUDAStream()>>>(
             MAKE_PTA_WITH_NAME(
-                func_name, hash_size_cumsum, index_t, 1, 32),
+                func_name, hash_size_cumsum, int64_t, 1, 32),
             MAKE_PTA_WITH_NAME(func_name, indices, index_t, 1, 32),
             MAKE_PTA_WITH_NAME(
                 func_name, total_L_offsets.value(), index_t, 1, 32),
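To make the kernel changes concrete, here is a host-side analogue of the linearization these kernels perform (illustrative, not the device code): per-table indices, which may now be int32_t, are shifted into a single global id space by the int64_t cumulative table sizes, so the arithmetic stays in 64 bits.

#include <cstdint>
#include <vector>

// Host-side analogue (illustrative): map each (table t, row index) pair
// to a single global id.
std::vector<int64_t> linearize(
    const std::vector<int64_t>& hash_size_cumsum, // size T + 1, int64_t
    const std::vector<int32_t>& indices,          // per-table row ids
    const std::vector<int32_t>& table_ids) {      // table t for each index
  std::vector<int64_t> linear(indices.size());
  for (size_t i = 0; i < indices.size(); ++i) {
    // int64_t offset + int32_t index promotes to int64_t, so the global
    // id is safe even when total rows across tables exceed 2^31 - 1.
    linear[i] = hash_size_cumsum[table_ids[i]] + indices[i];
  }
  return linear;
}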
