Skip to content

Commit

Permalink
Add support for int32_t indices in TBE training (2C/N) (pytorch#3372)
Browse files Browse the repository at this point in the history
Summary:
X-link: facebookresearch/FBGEMM#619

X-link: facebookresearch/FBGEMM#465


- Add `index_t` support to TBE training backward kernels

Differential Revision: D65925354
  • Loading branch information
q10 authored and facebook-github-bot committed Jan 6, 2025
1 parent 794f0e9 commit addc287
Showing 1 changed file with 8 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -821,6 +821,7 @@ Tensor {{ embedding_cuda_op }}(
else {
{{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
size_t temp_storage_bytes = 0;
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
AT_CUDA_CHECK(radix_sort_pairs(
nullptr,
temp_storage_bytes,
Expand All @@ -832,9 +833,11 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
Expand All @@ -846,6 +849,7 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
});
}
}
Expand All @@ -865,6 +869,7 @@ Tensor {{ embedding_cuda_op }}(
%}
{%- endif %}
AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
DISPATCH_EMB_GRAD_CACHE_TYPES(
dev_weights.scalar_type(),
aligned_grad_output.scalar_type(),
Expand All @@ -890,9 +895,11 @@ Tensor {{ embedding_cuda_op }}(
0,
total_hash_size_bits,
at::cuda::getCurrentCUDAStream()));
auto temp_storage = at::empty(
{static_cast<int64_t>(temp_storage_bytes)},
indices.options().dtype(at::kByte));
AT_CUDA_CHECK(radix_sort_pairs(
temp_storage.data_ptr(),
temp_storage_bytes,
Expand Down Expand Up @@ -1308,6 +1315,7 @@ Tensor {{ embedding_cuda_op }}(
}); // DISPATCH_OPTIMAL_KERNEL
}); // DISPATCH_EMB_GRAD_CACHE_TYPES
}); // AT_DISPATCH_INDEX_TYPES
{%- if dense %}
return grad_dev_weights;
Expand Down

0 comments on commit addc287

Please sign in to comment.