diff --git a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
index cf6218bfb6..0e958fcb25 100644
--- a/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
+++ b/fbgemm_gpu/fbgemm_gpu/split_table_batched_embeddings_ops.py
@@ -552,35 +552,51 @@ def forward(
 
         if self.optimizer == OptimType.ADAM:
             return invokers.lookup_adam.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                momentum2,
                 # pyre-fixme[29]:
                 #  `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
                 #  Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
                 #  not a function.
-                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                self.iter.item(),
             )
         if self.optimizer == OptimType.PARTIAL_ROWWISE_ADAM:
             return invokers.lookup_partial_rowwise_adam.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                momentum2,
                 # pyre-fixme[29]:
                 #  `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
                 #  Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
                 #  not a function.
-                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                self.iter.item(),
            )
         if self.optimizer == OptimType.LAMB:
             return invokers.lookup_lamb.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                momentum2,
                 # pyre-fixme[29]:
                 #  `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
                 #  Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
                 #  not a function.
-                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                self.iter.item(),
             )
         if self.optimizer == OptimType.PARTIAL_ROWWISE_LAMB:
             return invokers.lookup_partial_rowwise_lamb.invoke(
+                common_args,
+                self.optimizer_args,
+                momentum1,
+                momentum2,
                 # pyre-fixme[29]:
                 #  `Union[BoundMethod[typing.Callable(Tensor.item)[[Named(self,
                 #  Tensor)], typing.Union[float, int]], Tensor], Tensor, nn.Module]` is
                 #  not a function.
-                common_args, self.optimizer_args, momentum1, momentum2, self.iter.item()
+                self.iter.item(),
             )
 
         raise ValueError(f"Invalid OptimType: {self.optimizer}")
@@ -1055,6 +1071,7 @@ class DenseTableBatchedEmbeddingBagsCodegen(nn.Module):
     """
     Table-batched version of nn.EmbeddingBag(sparse=False)
     """
+
     weights: Tensor
     weights_offsets: Tensor
     D_offsets: Tensor
diff --git a/fbgemm_gpu/src/sparse_ops.cu b/fbgemm_gpu/src/sparse_ops.cu
index 32a4e623cf..110c9f042f 100644
--- a/fbgemm_gpu/src/sparse_ops.cu
+++ b/fbgemm_gpu/src/sparse_ops.cu
@@ -79,7 +79,7 @@ std::tuple> permute_sparse_data_cuda(
   // repetitions
   const auto T = permute.numel();
   const auto T_ = lengths.size(0);
-  const auto B = lengths.view({ lengths.sizes()[0], -1 }).sizes()[1];
+  const auto B = lengths.view({lengths.sizes()[0], -1}).sizes()[1];
 
   Tensor permuted_lengths;
   Tensor permuted_indices;
@@ -88,8 +88,7 @@ std::tuple> permute_sparse_data_cuda(
   permuted_lengths = at::empty({T, B}, lengths.options());
 
   constexpr int32_t threads_1 = 256;
-  const auto blocks_1 =
-      cuda_calc_xblock_count(B * T, threads_1);
+  const auto blocks_1 = cuda_calc_xblock_count(B * T, threads_1);
   AT_DISPATCH_INDEX_TYPES(
       lengths.scalar_type(), "permute_lengths_kernel", ([&] {
         permute_lengths_kernel
@@ -120,46 +119,46 @@ std::tuple> permute_sparse_data_cuda(
   constexpr int32_t BT_blocks = 32;
   dim3 threads_2(32, BT_blocks);
-  const auto blocks_2 =
-      cuda_calc_xblock_count(B * T, BT_blocks);
+  const auto blocks_2 = cuda_calc_xblock_count(B * T, BT_blocks);
   permuted_indices = at::empty(permuted_lengths_sum, indices.options());
   AT_DISPATCH_INDEX_TYPES(
       input_offsets.scalar_type(), "permute_data_kernel_1", ([&] {
         using offsets_t = index_t;
         AT_DISPATCH_ALL_TYPES(
-            indices.scalar_type(), "permute_data_kernel_2", ([&] {
-              using indices_t = scalar_t;
-              if (weights.has_value()) {
-                const Tensor weights_value = weights.value();
-                const auto weights_value_contig = weights_value.contiguous();
-                permuted_weights = at::empty(permuted_lengths_sum, weights_value.options());
-                AT_DISPATCH_FLOATING_TYPES(
-                    weights_value.scalar_type(), "permute_data_kernel_3", ([&] {
-                      using weights_t = scalar_t;
-                      permute_data_kernel
-                          <<>>(
-                              permuted_lengths_sum,
-                              T,
-                              B,
-                              indices_contig.data_ptr(),
-                              weights_value_contig.data_ptr(),
-                              permute_contig.data_ptr(),
-                              input_offsets.data_ptr(),
-                              output_offsets.data_ptr(),
-                              permuted_indices.data_ptr(),
-                              permuted_weights.data_ptr());
-                      C10_CUDA_KERNEL_LAUNCH_CHECK();
-                    })); // for each weights_t
+            indices.scalar_type(), "permute_data_kernel_2", ([&] {
+              using indices_t = scalar_t;
+              if (weights.has_value()) {
+                const Tensor weights_value = weights.value();
+                const auto weights_value_contig = weights_value.contiguous();
+                permuted_weights =
+                    at::empty(permuted_lengths_sum, weights_value.options());
+                AT_DISPATCH_FLOATING_TYPES(
+                    weights_value.scalar_type(), "permute_data_kernel_3", ([&] {
+                      using weights_t = scalar_t;
+                      permute_data_kernel
+                          <<>>(
+                              permuted_lengths_sum,
+                              T,
+                              B,
+                              indices_contig.data_ptr(),
+                              weights_value_contig.data_ptr(),
+                              permute_contig.data_ptr(),
+                              input_offsets.data_ptr(),
+                              output_offsets.data_ptr(),
+                              permuted_indices.data_ptr(),
+                              permuted_weights.data_ptr());
+                      C10_CUDA_KERNEL_LAUNCH_CHECK();
+                    })); // for each weights_t
               } else {
                 permute_data_kernel
                     <<>>(
+                       threads_2,
+                       0,
+                       at::cuda::getCurrentCUDAStream()>>>(
                         permuted_lengths_sum,
                         T,
                         B,
@@ -172,8 +171,8 @@ std::tuple> permute_sparse_data_cuda(
                         nullptr);
                 C10_CUDA_KERNEL_LAUNCH_CHECK();
               }
-              })); // for each indices_t
-          })); // for each offsets_t
+            })); // for each indices_t
+      })); // for each offsets_t
 
   return {permuted_lengths, permuted_indices, permuted_weights};
 }