diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index 835afec943..3caf531976 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -51,21 +51,7 @@ FBGEMM_API bool EmbeddingSpMDM( bool is_weight_positional = false); template -class EmbeddingSpMDM4BitKernelSignature { - public: - using Type = std::function; -}; - -template -FBGEMM_API typename EmbeddingSpMDM4BitKernelSignature::Type +FBGEMM_API typename EmbeddingSpMDMKernelSignature::Type GenerateEmbeddingSpMDM4Bit( const std::int64_t block_size, bool has_weight, diff --git a/src/EmbeddingSpMDM4Bit.cc b/src/EmbeddingSpMDM4Bit.cc index 1130383bba..531e4f7466 100644 --- a/src/EmbeddingSpMDM4Bit.cc +++ b/src/EmbeddingSpMDM4Bit.cc @@ -570,7 +570,7 @@ GenEmbeddingSpMDM4BitLookup::getOrCreate( } // namespace template -typename EmbeddingSpMDM4BitKernelSignature::Type +typename EmbeddingSpMDMKernelSignature::Type GenerateEmbeddingSpMDM4Bit( const std::int64_t block_size, bool has_weight, @@ -648,21 +648,23 @@ bool EmbeddingSpMDM4Bit( out); } -template typename EmbeddingSpMDM4BitKernelSignature::Type -GenerateEmbeddingSpMDM4Bit( - const std::int64_t block_size, - bool has_weight, - bool normalize_by_lengths, - int prefetch, - bool is_weight_positional); - -template typename EmbeddingSpMDM4BitKernelSignature::Type -GenerateEmbeddingSpMDM4Bit( - const std::int64_t block_size, - bool has_weight, - bool normalize_by_lengths, - int prefetch, - bool is_weight_positional); +template + typename EmbeddingSpMDMKernelSignature::Type + GenerateEmbeddingSpMDM4Bit( + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch, + bool is_weight_positional); + +template + typename EmbeddingSpMDMKernelSignature::Type + GenerateEmbeddingSpMDM4Bit( + const std::int64_t block_size, + bool has_weight, + bool normalize_by_lengths, + int prefetch, + bool is_weight_positional); template bool EmbeddingSpMDM4Bit( const std::int64_t block_size,