diff --git a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
index 5029a382a..3199a1b00 100644
--- a/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
+++ b/fbgemm_gpu/codegen/training/backward/embedding_backward_split_template.cu
@@ -821,6 +821,7 @@ Tensor {{ embedding_cuda_op }}(
     else {
       {{ locs_or_addrs_tensor }}_sorted = at::empty_like({{ locs_or_addrs_tensor }});
       size_t temp_storage_bytes = 0;
+      AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_1", [&] {
       AT_CUDA_CHECK(radix_sort_pairs(
           nullptr,
           temp_storage_bytes,
@@ -832,9 +833,11 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
+
       AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
@@ -846,6 +849,7 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+      });
     }
   }
 
@@ -865,6 +869,7 @@ Tensor {{ embedding_cuda_op }}(
   %}
   {%- endif %}
 
+  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "{{ embedding_cuda_op }}_2", [&] {
   DISPATCH_EMB_GRAD_CACHE_TYPES(
       dev_weights.scalar_type(),
       aligned_grad_output.scalar_type(),
@@ -890,9 +895,11 @@ Tensor {{ embedding_cuda_op }}(
           0,
           total_hash_size_bits,
           at::cuda::getCurrentCUDAStream()));
+
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           indices.options().dtype(at::kByte));
+
       AT_CUDA_CHECK(radix_sort_pairs(
           temp_storage.data_ptr(),
           temp_storage_bytes,
@@ -1308,6 +1315,7 @@ Tensor {{ embedding_cuda_op }}(
       }); // DISPATCH_OPTIMAL_KERNEL
   }); // DISPATCH_EMB_GRAD_CACHE_TYPES
+  }); // AT_DISPATCH_INDEX_TYPES
 
   {%- if dense %}
   return grad_dev_weights;
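
A minimal sketch of the dispatch pattern the template now wraps the sort calls in, assuming only PyTorch's stock `AT_DISPATCH_INDEX_TYPES` macro: it switches on the runtime dtype of an index tensor (int32 or int64) and exposes the matching C++ type as `index_t` inside the lambda, so code in the lambda body can read the indices through `data_ptr<index_t>()` without hard-coding a width. The `sum_indices` helper below is hypothetical and only illustrates the macro, not the FBGEMM kernel.

```cpp
// Hedged sketch, not FBGEMM code: how AT_DISPATCH_INDEX_TYPES binds index_t.
#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical helper: sum an int32/int64 index tensor (assumes a CPU tensor).
int64_t sum_indices(const at::Tensor& indices) {
  int64_t total = 0;
  // The string argument is only used in error messages for unsupported dtypes,
  // mirroring the "{{ embedding_cuda_op }}_1" / "_2" names in the diff above.
  AT_DISPATCH_INDEX_TYPES(indices.scalar_type(), "sum_indices", [&] {
    // Inside the lambda, index_t is int32_t or int64_t, matching the tensor.
    const index_t* ptr = indices.data_ptr<index_t>();
    for (int64_t i = 0; i < indices.numel(); ++i) {
      total += static_cast<int64_t>(ptr[i]);
    }
  });
  return total;
}
```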