Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2C/N) (pytorch#3372)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#465


- Add `index_t` support to TBE training backward kernels

Differential Revision: D65925354
  • Loading branch information
q10 authored and facebook-github-bot committed Dec 14, 2024
1 parent d90e68b commit 71f02ee
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -742,6 +742,7 @@ Tensor {{ embedding_cuda_op }}(
else {
{{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
size_t temp_storage_bytes = 0;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
AT_CUDA_CHECK(radix_sort_pairs(
nullptr,
temp_storage_bytes,
Expand All @@ -753,9 +754,11 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
Expand All @@ -767,6 +770,7 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
});
}
}
Expand All @@ -775,6 +779,7 @@ Tensor {{ embedding_cuda_op }}(
}
{%- endif %}
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
Expand All @@ -800,9 +805,11 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
Expand Down Expand Up @@ -1181,6 +1188,7 @@ Tensor {{ embedding_cuda_op }}(
}); // DISPATCH_OPTIMAL_KERNEL
}); // DISPATCH_EMB_GRAD_CACHE_TYPES
}); // AT_DISPATCH_INDEX_TYPES
{%- if dense %}
return grad_dev_weights;
Expand Down

0 comments on commit 71f02ee

Please sign in to comment.