Change AT_DISPATCH_INTEGRAL_TYPES to AT_DISPATCH_INDEX_TYPES to reduce dispatch types (pytorch#1040)

Summary:
Pull Request resolved: pytorch#1040

As title

Reviewed By: jasonjk-park

Differential Revision: D35531965

fbshipit-source-id: cb81c411cbb8706cd6bc50494a39f6285a032e22
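
For context (not part of this commit): AT_DISPATCH_INDEX_TYPES dispatches only over at::kInt and at::kLong and binds the element type as index_t, whereas AT_DISPATCH_INTEGRAL_TYPES instantiates the lambda for all five integral scalar types (uint8, int8, int16, int32, int64) as scalar_t. Since these CUB prefix-sum wrappers only ever see int32/int64 lengths and offsets, the narrower macro compiles fewer template instantiations per call site. A minimal host-side sketch of how the macro is used (hypothetical function name, contiguous 1-D CPU tensor assumed, not code from this commit):

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>

// Hypothetical illustration: the macro switches on the runtime dtype and runs
// the lambda with index_t bound to int32_t or int64_t; any other dtype raises
// a "not implemented" error.
at::Tensor example_inclusive_cumsum_cpu(const at::Tensor& t_in) {
  auto t_out = at::empty_like(t_in);
  AT_DISPATCH_INDEX_TYPES(
      t_in.scalar_type(), "example_inclusive_cumsum_cpu", [&] {
        const index_t* in = t_in.data_ptr<index_t>();
        index_t* out = t_out.data_ptr<index_t>();
        index_t running = 0;
        for (int64_t i = 0; i < t_in.numel(); ++i) {
          running += in[i];
          // Inclusive prefix sum, done on the host only to keep the example short.
          out[i] = running;
        }
      });
  return t_out;
}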
jianyuh authored and facebook-github-bot committed Apr 11, 2022
1 parent 64c9c39 commit 9147ea2
Showing 3 changed files with 36 additions and 36 deletions.
fbgemm_gpu/src/jagged_tensor_ops.cu (24 changes: 12 additions & 12 deletions)

@@ -1156,26 +1156,26 @@ stacked_jagged_2d_to_dense_forward_cuda(
     size_t temp_storage_bytes = 0;
     auto offsets = at::empty({B + 1}, lengths.options());
     offsets[0].zero_();
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               nullptr,
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
               B,
               at::cuda::getCurrentCUDAStream()));
         });
     auto temp_storage = at::empty(
         {static_cast<int64_t>(temp_storage_bytes)},
         lengths.options().dtype(at::kByte));
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               temp_storage.data_ptr(),
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
               B,
               at::cuda::getCurrentCUDAStream()));
         });

@@ -1248,26 +1248,26 @@ std::vector<Tensor> stacked_jagged_1d_to_dense_gpu(
   for (int32_t t = 0; t < T; t++) {
     int64_t max_L = max_lengths_per_key[t];
     size_t temp_storage_bytes = 0;
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               nullptr,
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
               B,
               at::cuda::getCurrentCUDAStream()));
         });
     auto temp_storage = at::empty(
         {static_cast<int64_t>(temp_storage_bytes)},
         lengths.options().dtype(at::kByte));
-    AT_DISPATCH_INTEGRAL_TYPES(
+    AT_DISPATCH_INDEX_TYPES(
         lengths_contig.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
           AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
               temp_storage.data_ptr(),
               temp_storage_bytes,
-              &(lengths_contig.data_ptr<scalar_t>()[t * B]),
-              offsets.data_ptr<scalar_t>() + 1,
+              &(lengths_contig.data_ptr<index_t>()[t * B]),
+              offsets.data_ptr<index_t>() + 1,
               B,
               at::cuda::getCurrentCUDAStream()));
         });
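
All of the hunks in this commit follow CUB's standard two-phase pattern, which is why each wrapper appears twice: the first DeviceScan call is made with a null temp-storage pointer and only writes the required scratch size into temp_storage_bytes, the scratch buffer is then allocated (above, as an at::kByte tensor), and the second, otherwise identical call performs the actual scan. A standalone sketch of that pattern on raw device pointers, with hypothetical names and no error handling:

#include <cstdint>
#include <cub/cub.cuh>
#include <cuda_runtime.h>

// Hypothetical standalone example of the two-phase cub::DeviceScan workflow
// used in the hunks above (raw CUDA allocations instead of ATen tensors).
void inclusive_scan_i64(
    const int64_t* d_in, int64_t* d_out, int num_items, cudaStream_t stream) {
  void* d_temp_storage = nullptr;
  size_t temp_storage_bytes = 0;
  // Pass 1: d_temp_storage == nullptr, so CUB only reports the scratch size.
  cub::DeviceScan::InclusiveSum(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  cudaMalloc(&d_temp_storage, temp_storage_bytes);
  // Pass 2: same arguments, now with scratch space; runs the scan on `stream`.
  cub::DeviceScan::InclusiveSum(
      d_temp_storage, temp_storage_bytes, d_in, d_out, num_items, stream);
  cudaFree(d_temp_storage);
}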
fbgemm_gpu/src/sparse_ops.cu (36 changes: 18 additions & 18 deletions)

@@ -197,26 +197,26 @@ Tensor asynchronous_inclusive_cumsum_gpu(const Tensor& t_in) {
   // CUB only handles up to INT_MAX elements.
   TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
   auto t_out = at::empty_like(t_in);
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });

@@ -234,26 +234,26 @@ Tensor asynchronous_exclusive_cumsum_gpu(const Tensor& t_in) {
   // CUB only handles up to INT_MAX elements.
   TORCH_CHECK(t_in.numel() < std::numeric_limits<int32_t>::max());
   auto t_out = at::empty_like(t_in);
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_exclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_exclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::ExclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>(),
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>(),
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });

@@ -273,26 +273,26 @@ Tensor asynchronous_complete_cumsum_gpu(const Tensor& t_in) {
   TORCH_CHECK(t_in.dim() == 1);
   auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
   t_out[0].zero_();
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
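
The asynchronous_complete_cumsum_gpu hunk also shows why the output pointer is offset by one: the result has numel() + 1 elements, element 0 is zeroed, and the inclusive scan is written starting at index 1, so the output is effectively the exclusive cumsum with the grand total appended. This is the usual lengths-to-offsets transform for jagged/sparse layouts. A small host-side illustration of the same arithmetic (hypothetical helper, plain std::vector, not from this commit):

#include <cstdint>
#include <vector>

// Hypothetical host-side analogue of the "complete" cumsum above:
// lengths {2, 0, 3} -> offsets {0, 2, 2, 5}; offsets[i]..offsets[i + 1]
// delimit row i, and offsets.back() is the total number of elements.
std::vector<int64_t> complete_cumsum(const std::vector<int64_t>& lengths) {
  std::vector<int64_t> offsets(lengths.size() + 1, 0);
  for (size_t i = 0; i < lengths.size(); ++i) {
    offsets[i + 1] = offsets[i] + lengths[i]; // inclusive sum shifted by one
  }
  return offsets;
}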
fbgemm_gpu/src/split_embeddings_utils.cu (12 changes: 6 additions & 6 deletions)

@@ -29,26 +29,26 @@ inline at::Tensor asynchronous_complete_cumsum(at::Tensor t_in) {
   TORCH_CHECK(t_in.dim() == 1);
   auto t_out = at::empty({t_in.numel() + 1}, t_in.options());
   t_out[0].zero_();
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper1", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             nullptr,
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
   auto temp_storage = at::empty(
       {static_cast<int64_t>(temp_storage_bytes)},
       t_in.options().dtype(at::kByte));
-  AT_DISPATCH_INTEGRAL_TYPES(
+  AT_DISPATCH_INDEX_TYPES(
       t_in.scalar_type(), "cub_inclusive_sum_wrapper2", [&] {
         AT_CUDA_CHECK(FBGEMM_GPU_CUB_NS_PREFIX cub::DeviceScan::InclusiveSum(
             temp_storage.data_ptr(),
             temp_storage_bytes,
-            t_in.data_ptr<scalar_t>(),
-            t_out.data_ptr<scalar_t>() + 1,
+            t_in.data_ptr<index_t>(),
+            t_out.data_ptr<index_t>() + 1,
            t_in.numel(),
            at::cuda::getCurrentCUDAStream()));
       });
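
One practical consequence of the narrower macro, not spelled out in the commit message: uint8/int8/int16 lengths tensors, which AT_DISPATCH_INTEGRAL_TYPES would previously have dispatched, now hit the macro's runtime "not implemented" error, while int32/int64 inputs behave exactly as before. A hedged caller-side sketch of a guard a wrapper could apply before invoking these kernels (hypothetical helper, not part of this commit):

#include <ATen/ATen.h>
#include <c10/util/Exception.h>

// Hypothetical pre-dispatch guard: pass int32/int64 through unchanged and
// widen narrower integer lengths so they land in one of the two dispatched
// branches of AT_DISPATCH_INDEX_TYPES.
at::Tensor to_index_dtype(const at::Tensor& lengths) {
  if (lengths.scalar_type() == at::kInt || lengths.scalar_type() == at::kLong) {
    return lengths;
  }
  TORCH_CHECK(
      at::isIntegralType(lengths.scalar_type(), /*includeBool=*/false),
      "lengths must be an integral tensor");
  return lengths.to(at::kLong); // widen uint8/int8/int16 lengths before calling
}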
