diff --git a/fbgemm_gpu/src/jagged_tensor_ops.cu b/fbgemm_gpu/src/jagged_tensor_ops.cu
index 09def83cf9..7eb771d383 100644
--- a/fbgemm_gpu/src/jagged_tensor_ops.cu
+++ b/fbgemm_gpu/src/jagged_tensor_ops.cu
@@ -1156,26 +1156,26 @@ stacked_jagged_2d_to_dense_forward_cuda(
       size_t temp_storage_bytes = 0;
       auto offsets = at::empty({B + 1}, lengths.options());
       offsets[0].zero_();
-      AT_DISPATCH_INTEGRAL_TYPES(
+      AT_DISPATCH_INDEX_TYPES(
           lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
             AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
                 nullptr,
                 temp_storage_bytes,
-                &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-                offsets.data_ptr<scalar_t>() + 1,
+                &(lengths_contig.data_ptr<index_t>()[t * B]),
+                offsets.data_ptr<index_t>() + 1,
                 B,
                 at::cuda::getCurrentCUDAStream()));
           });
       auto temp_storage = at::empty(
           {static_cast<int64_t>(temp_storage_bytes)},
           lengths.options().dtype(at::kByte));
-      AT_DISPATCH_INTEGRAL_TYPES(
+      AT_DISPATCH_INDEX_TYPES(
           lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
             AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
                 temp_storage.data_ptr(),
                 temp_storage_bytes,
-                &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-                offsets.data_ptr<scalar_t>() + 1,
+                &(lengths_contig.data_ptr<index_t>()[t * B]),
+                offsets.data_ptr<index_t>() + 1,
                 B,
                 at::cuda::getCurrentCUDAStream()));
           });
@@ -1248,26 +1248,26 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
   for (int32_t t = 0; t < T; t++) {
     int64_t max_L = max_lengths_per_key[t];
     size_t temp_storage_bytes = 0;
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               nullptr,
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
              B,
              at::cuda::getCurrentCUDAStream()));
         });
     auto temp_storage = at::empty(
         {static_cast<int64_t>(temp_storage_bytes)},
         lengths.options().dtype(at::kByte));
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               temp_storage.data_ptr(),
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
              B,
              at::cuda::getCurrentCUDAStream()));
         });
diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index 56104f540f..d7f1bc5b0f 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -197,26 +197,26 @@ Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) {
   // CUB only handles up to INT_MAX elements.
   TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
   auto t_out = at::empty_like(t_in);
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
@@ -234,26 +234,26 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
   // CUB only handles up to INT_MAX elements.
   TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
   auto t_out = at::empty_like(t_in);
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_exclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_exclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
@@ -273,26 +273,26 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   TORCH_CHECK(t_in.dim() == 1);
   auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
   t_out[0].zero_();
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
diff --git a/fbgemm_gpu/src/split_embeddings_utils.cu b/fbgemm_gpu/src/split_embeddings_utils.cu
index 437f185317..f5dd9ec04b 100644
--- a/fbgemm_gpu/src/split_embeddings_utils.cu
+++ b/fbgemm_gpu/src/split_embeddings_utils.cu
@@ -29,26 +29,26 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
   TORCH_CHECK(t_in.dim() == 1);
   auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
   t_out[0].zero_();
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
             t_in.numel(),
             at::cuda::getCurrentCUDAStream()));
       });
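
Note: the snippet below is a minimal standalone sketch, not part of the patch, illustrating the dispatch pattern the diff converges on. AT_DISPATCH_INDEX_TYPES dispatches only on int32/int64 inputs and exposes the element type as index_t, which types the pointers handed to CUB's two-phase DeviceScan API (a null-workspace call to query the temporary-storage size, then the real scan). The function name inclusive_cumsum_sketch and the exact include set are illustrative; the FBGEMM_GPU_CUB_NS_PREFIX wrapper used in the real sources is omitted here.

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/Exceptions.h>
#include <cub/cub.cuh>
#include <limits>

at::Tensor inclusive_cumsum_sketch(const at::Tensor& t_in) {
  // CUB's num_items parameter is an int, so mirror the diff's INT_MAX guard.
  TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
  auto t_out = at::empty_like(t_in);
  size_t temp_storage_bytes = 0;
  AT_DISPATCH_INDEX_TYPES(
      t_in.scalar_type(), "inclusive_cumsum_sketch_query", [&] {
        // First call with a null workspace only computes temp_storage_bytes.
        AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
            nullptr,
            temp_storage_bytes,
            t_in.data_ptr<index_t>(),
            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
      });
  // Allocate the workspace as a byte tensor on the same device as the input.
  auto temp_storage = at::empty(
      {static_cast<int64_t>(temp_storage_bytes)},
      t_in.options().dtype(at::kByte));
  AT_DISPATCH_INDEX_TYPES(
      t_in.scalar_type(), "inclusive_cumsum_sketch_run", [&] {
        // Second call performs the actual inclusive scan.
        AT_CUDA_CHECK(cub::DeviceScan::InclusiveSum(
            temp_storage.data_ptr(),
            temp_storage_bytes,
            t_in.data_ptr<index_t>(),
            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
      });
  return t_out;
}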