From 1ea1a02148c6d04b483ca33ef85afe25e19fa3ac Mon Sep 17 00:00:00 2001 From: Hongzhang Shan Date: Fri, 31 Jan 2020 08:47:02 -0800 Subject: [PATCH] add more instantiation for EmbeddingSpMDMAvx2 (#274) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/274 Solved unresolved references in test/EmbeddingSpMDMTest.cc. Add instantiation for input type uint8_t. Reviewed By: jspark1105 Differential Revision: D19654528 fbshipit-source-id: 5dca40e1f6530c5720f290c27730e2e13bd9dce7 --- src/EmbeddingSpMDMAvx2.cc | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/src/EmbeddingSpMDMAvx2.cc b/src/EmbeddingSpMDMAvx2.cc index 1f43dd26b0..5ff7b056b7 100644 --- a/src/EmbeddingSpMDMAvx2.cc +++ b/src/EmbeddingSpMDMAvx2.cc @@ -173,5 +173,29 @@ template bool EmbeddingSpMDMBlockSize1_( float* out, bool is_weight_positional); +template bool EmbeddingSpMDMBlockSize1_( + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, // the number of rows in input + const std::uint8_t* input, + const std::int64_t* indices, + const int* lengths, + const float* weights, // optional, can be null for non-weighted sum + bool normalize_by_lengths, + float* out, + bool is_weight_positional); + +template bool EmbeddingSpMDMBlockSize1_( + const std::int64_t output_size, + const std::int64_t index_size, + const std::int64_t data_size, // the number of rows in input + const std::uint8_t* input, + const std::int32_t* indices, + const int* lengths, + const float* weights, // optional, can be null for non-weighted sum + bool normalize_by_lengths, + float* out, + bool is_weight_positional); + } // namespace internal } // namespace fbgemm