Skip to content

Commit

Permalink
format sparse_ops.cu (pytorch#604)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#604

Formatting-only change to prepare for D28312997 and D28141945.

Reviewed By: jianyuh

Differential Revision: D28312975

fbshipit-source-id: 8d96a7a846bb9bfe0c8864b64bcbdd3c0c42a23d
  • Loading branch information
jspark1105 authored and facebook-github-bot committed May 10, 2021
1 parent 580d637 commit 566d74c
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 40 deletions.
25 changes: 21 additions & 4 deletions fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -552,35 +552,51 @@ def forward(

if self.optimizer == OptimType.ADAM:
return invokers.lookup_adam.invoke(
common_args,
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
# Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
# not a function.
common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
self.iter.item(),
)
if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
return invokers.lookup_partial_rowwise_adam.invoke(
common_args,
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
# Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
# not a function.
common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
self.iter.item(),
)
if self.optimizer == OptimType.LAMB:
return invokers.lookup_lamb.invoke(
common_args,
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
# Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
# not a function.
common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
self.iter.item(),
)
if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
return invokers.lookup_partial_rowwise_lamb.invoke(
common_args,
self.optimizer_args,
momentum1,
momentum2,
# pyre-fixme[29]:
# `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
# Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
# not a function.
common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
self.iter.item(),
)

raise ValueError(f"Invalid OptimType: {self.optimizer}")
Expand Down Expand Up @@ -1055,6 +1071,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
"""
Table-batched version of nn.EmbeddingBag(sparse=False)
"""

weights: Tensor
weights_offsets: Tensor
D_offsets: Tensor
Expand Down
71 changes: 35 additions & 36 deletions fbgemm_gpu/src/sparse_ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
// repetitions
const auto T = permute.numel();
const auto T_ = lengths.size(0);
const auto B = lengths.view({ lengths.sizes()[0], -1 }).sizes()[1];
const auto B = lengths.view({lengths.sizes()[0], -1}).sizes()[1];

Tensor permuted_lengths;
Tensor permuted_indices;
Expand All @@ -88,8 +88,7 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
permuted_lengths = at::empty({T, B}, lengths.options());

constexpr int32_t threads_1 = 256;
const auto blocks_1 =
cuda_calc_xblock_count(B * T, threads_1);
const auto blocks_1 = cuda_calc_xblock_count(B * T, threads_1);
AT_DISPATCH_INDEX_TYPES(
lengths.scalar_type(), "permute_lengths_kernel", ([&] {
permute_lengths_kernel<index_t>
Expand Down Expand Up @@ -120,46 +119,46 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(

constexpr int32_t BT_blocks = 32;
dim3 threads_2(32, BT_blocks);
const auto blocks_2 =
cuda_calc_xblock_count(B * T, BT_blocks);
const auto blocks_2 = cuda_calc_xblock_count(B * T, BT_blocks);
permuted_indices = at::empty(permuted_lengths_sum, indices.options());

AT_DISPATCH_INDEX_TYPES(
input_offsets.scalar_type(), "permute_data_kernel_1", ([&] {
using offsets_t = index_t;
AT_DISPATCH_ALL_TYPES(
indices.scalar_type(), "permute_data_kernel_2", ([&] {
using indices_t = scalar_t;
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights = at::empty(permuted_lengths_sum, weights_value.options());
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(), "permute_data_kernel_3", ([&] {
using weights_t = scalar_t;
permute_data_kernel<true, offsets_t, indices_t, weights_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_lengths_sum,
T,
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
})); // for each weights_t
indices.scalar_type(), "permute_data_kernel_2", ([&] {
using indices_t = scalar_t;
if (weights.has_value()) {
const Tensor weights_value = weights.value();
const auto weights_value_contig = weights_value.contiguous();
permuted_weights =
at::empty(permuted_lengths_sum, weights_value.options());
AT_DISPATCH_FLOATING_TYPES(
weights_value.scalar_type(), "permute_data_kernel_3", ([&] {
using weights_t = scalar_t;
permute_data_kernel<true, offsets_t, indices_t, weights_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_lengths_sum,
T,
B,
indices_contig.data_ptr<indices_t>(),
weights_value_contig.data_ptr<weights_t>(),
permute_contig.data_ptr<int32_t>(),
input_offsets.data_ptr<offsets_t>(),
output_offsets.data_ptr<offsets_t>(),
permuted_indices.data_ptr<indices_t>(),
permuted_weights.data_ptr<weights_t>());
C10_CUDA_KERNEL_LAUNCH_CHECK();
})); // for each weights_t
} else {
permute_data_kernel<false, offsets_t, indices_t, std::nullptr_t>
<<<blocks_2,
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
threads_2,
0,
at::cuda::getCurrentCUDAStream()>>>(
permuted_lengths_sum,
T,
B,
Expand All @@ -172,8 +171,8 @@ std::tuple<Tensor, Tensor, c10::optional<Tensor>> permute_sparse_data_cuda(
nullptr);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
})); // for each indices_t
})); // for each offsets_t
})); // for each indices_t
})); // for each offsets_t
return {permuted_lengths, permuted_indices, permuted_weights};
}
Expand Down

0 comments on commit 566d74c

Please sign in to comment.