From 1a7be35d9445038470a16d299dd37b5fe95c6222 Mon Sep 17 00:00:00 2001 From: Jongsoo Park Date: Wed, 19 Feb 2020 00:54:56 -0800 Subject: [PATCH] fix index mapping table type to int32_t (#302) Summary: Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/302 To address a challenge of detecting index type during model transformation. We shouldn't expect the number of rows in an embedding table will exceed int32_t range especially after row-wise pruning Reviewed By: dehuacheng, jianyuh Differential Revision: D19954308 fbshipit-source-id: 8d152f747f7e95baa936bd66494414b516906b09 --- ...mbeddingSpMDMNBitRowWiseSparseBenchmark.cc | 11 +- include/fbgemm/FbgemmEmbedding.h | 2 +- src/EmbeddingSpMDM.cc | 81 +++++------- src/EmbeddingSpMDMNBit.cc | 58 +++----- src/RefImplementations.cc | 124 +++++++++--------- src/RefImplementations.h | 4 +- test/EmbeddingSpMDM8BitTest.cc | 11 +- test/EmbeddingSpMDMNBitTest.cc | 11 +- test/EmbeddingSpMDMTest.cc | 15 +-- test/EmbeddingSpMDMTestUtils.cc | 7 +- test/EmbeddingSpMDMTestUtils.h | 3 +- 11 files changed, 135 insertions(+), 192 deletions(-) diff --git a/bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc b/bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc index d400666e0f..43fa3aae58 100644 --- a/bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc +++ b/bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc @@ -68,7 +68,7 @@ int run_benchmark( // Generate mapping table default_random_engine generator; constexpr float sparsity = 0.7; - vector mapping_table(num_rows); + vector mapping_table(num_rows); bernoulli_distribution row_prune_dist(sparsity); int num_compressed_rows = 0; for (int i = 0; i < num_rows; ++i) { @@ -80,11 +80,6 @@ int run_benchmark( ++num_compressed_rows; } } - vector mapping_table_32; - copy( - mapping_table.begin(), - mapping_table.end(), - back_inserter(mapping_table_32)); // Create embedding table int num_elem_per_byte = 8 / bit_rate; @@ -192,7 +187,7 @@ int run_benchmark( num_rows, fused_embedding_table.data(), indices_32.data(), - mapping_table_32.data(), + mapping_table.data(), lengths.data(), has_weight ? weights.data() : nullptr, normalize_by_lengths, @@ -243,7 +238,7 @@ int run_benchmark( lengths.data(), has_weight ? weights.data() : nullptr, output.data(), - mapping_table_32.data()); + mapping_table.data()); } else { success = kernel_64( batch_size, diff --git a/include/fbgemm/FbgemmEmbedding.h b/include/fbgemm/FbgemmEmbedding.h index 081eedc646..91f9b8459d 100644 --- a/include/fbgemm/FbgemmEmbedding.h +++ b/include/fbgemm/FbgemmEmbedding.h @@ -66,7 +66,7 @@ class EmbeddingSpMDMRowWiseSparseKernelSignature { const int* lengths, const float* weights, // optional, can be null for non-weighted sum float* out, - const IndexType* compressed_indices_table)>; + const std::int32_t* compressed_indices_table)>; }; /** diff --git a/src/EmbeddingSpMDM.cc b/src/EmbeddingSpMDM.cc index 2a9b4b4810..2a11e6b5ac 100644 --- a/src/EmbeddingSpMDM.cc +++ b/src/EmbeddingSpMDM.cc @@ -60,7 +60,7 @@ class ReturnFunctionSignature { const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table, + const std::int32_t* compressed_indices_table, const int* mask); }; @@ -203,18 +203,19 @@ typename ReturnFunctionSignature:: asmjit::FuncDetail func; if (ROWWISE_SPARSE) { - func.init(asmjit::FuncSignatureT< - bool, - std::int64_t, // output_size - std::int64_t, // index_size - std::int64_t, // uncompressed_data_size - const inType*, // input uint8_t or float - const indxType*, // indices - const int*, // lengths - const float*, // weights - float*, // out - const indxType*, // compressed_indices_table and then mask - const int*>(asmjit::CallConv::kIdHost)); + func.init( + asmjit::FuncSignatureT< + bool, + std::int64_t, // output_size + std::int64_t, // index_size + std::int64_t, // uncompressed_data_size + const inType*, // input uint8_t or float + const indxType*, // indices + const int*, // lengths + const float*, // weights + float*, // out + const std::int32_t*, // compressed_indices_table and then mask + const int*>(asmjit::CallConv::kIdHost)); } else { func.init(asmjit::FuncSignatureT< bool, @@ -450,21 +451,12 @@ typename ReturnFunctionSignature:: a->jge(error); if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->mov( - scratchReg1_, - x86::qword_ptr( - compressed_indices_table, - scratchReg1_, - 3)); // use of 3 is to multiply by 8 - } else { - a->mov( - scratchReg1_.r32(), - x86::dword_ptr( - compressed_indices_table, - scratchReg1_, - 2)); // use of 2 is to multiply by 4 - } + a->mov( + scratchReg1_.r32(), + x86::dword_ptr( + compressed_indices_table, + scratchReg1_, + 2)); // use of 2 is to multiply by 4 } int fused_block_size = is8bit @@ -511,21 +503,12 @@ typename ReturnFunctionSignature:: a->bind(pref_dist_reset_end); if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->mov( - scratchReg2_, - x86::qword_ptr( - compressed_indices_table, - scratchReg2_, - 3)); // use of 3 is to multiply by 8 - } else { - a->mov( - scratchReg2_.r32(), - x86::dword_ptr( - compressed_indices_table, - scratchReg2_, - 2)); // use of 2 is to multiply by 4 - } + a->mov( + scratchReg2_.r32(), + x86::dword_ptr( + compressed_indices_table, + scratchReg2_, + 2)); // use of 2 is to multiply by 4 } a->imul(scratchReg2_, static_cast(fused_block_size)); } @@ -538,11 +521,7 @@ typename ReturnFunctionSignature:: } if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->cmp(scratchReg1_, static_cast(-1)); - } else { - a->cmp(scratchReg1_.r32(), static_cast(-1)); - } + a->cmp(scratchReg1_.r32(), static_cast(-1)); a->je(LoopDataIndexBegin); } @@ -922,7 +901,7 @@ GenerateEmbeddingSpMDMRowWiseSparse( const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table) { + const std::int32_t* compressed_indices_table) { return original_func( output_size, index_size, @@ -953,7 +932,7 @@ GenerateEmbeddingSpMDMRowWiseSparse( const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table) { + const std::int32_t* compressed_indices_table) { return original_func( output_size, index_size, @@ -979,7 +958,7 @@ GenerateEmbeddingSpMDMRowWiseSparse( const int* lengths, const float* weights, // optional, can be null for non-weighted sum float* out, - const indxType* compressed_indices_table) { + const std::int32_t* compressed_indices_table) { return EmbeddingSpMDMRowWiseSparse_ref( block_size, output_size, diff --git a/src/EmbeddingSpMDMNBit.cc b/src/EmbeddingSpMDMNBit.cc index 8067001ff7..42ba26da29 100644 --- a/src/EmbeddingSpMDMNBit.cc +++ b/src/EmbeddingSpMDMNBit.cc @@ -67,7 +67,7 @@ class ReturnFunctionSignature { const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table, + const int32_t* compressed_indices_table, const int* mask); }; @@ -217,7 +217,7 @@ GenEmbeddingSpMDMNBitLookup::getOrCreate( const int*, // lengths const float*, // weights float*, // out - const indxType* /* compressed_indices_table */, + const int32_t* /* compressed_indices_table */, const int* /* mask */>(asmjit::CallConv::kIdHost)); } else { func.init(asmjit::FuncSignatureT< @@ -493,21 +493,12 @@ GenEmbeddingSpMDMNBitLookup::getOrCreate( a->jge(error); if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->mov( - scratchReg1_, - x86::qword_ptr( - compressed_indices_table, - scratchReg1_, - 3)); // use of 3 is to multiply by 8 - } else { - a->mov( - scratchReg1_.r32(), - x86::dword_ptr( - compressed_indices_table, - scratchReg1_, - 2)); // use of 2 is to multiply by 4 - } + a->mov( + scratchReg1_.r32(), + x86::dword_ptr( + compressed_indices_table, + scratchReg1_, + 2)); // use of 2 is to multiply by 4 } int num_elem_per_byte = 8 / bit_rate; @@ -553,21 +544,12 @@ GenEmbeddingSpMDMNBitLookup::getOrCreate( a->bind(pref_dist_reset_end); if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->mov( - scratchReg2_, - x86::qword_ptr( - compressed_indices_table, - scratchReg2_, - 3)); // use of 3 is to multiply by 8 - } else { - a->mov( - scratchReg2_.r32(), - x86::dword_ptr( - compressed_indices_table, - scratchReg2_, - 2)); // use of 2 is to multiply by 4 - } + a->mov( + scratchReg2_.r32(), + x86::dword_ptr( + compressed_indices_table, + scratchReg2_, + 2)); // use of 2 is to multiply by 4 } // This has to be fused_block_size a->imul(scratchReg2_, static_cast(fused_block_size)); @@ -581,11 +563,7 @@ GenEmbeddingSpMDMNBitLookup::getOrCreate( } if (ROWWISE_SPARSE) { - if (areIndices64b) { - a->cmp(scratchReg1_, static_cast(-1)); - } else { - a->cmp(scratchReg1_.r32(), static_cast(-1)); - } + a->cmp(scratchReg1_.r32(), static_cast(-1)); a->je(LoopDataIndexBegin); } @@ -977,7 +955,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table) { + const int32_t* compressed_indices_table) { return original_func( output_size, index_size, @@ -1009,7 +987,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table) { + const int32_t* compressed_indices_table) { return original_func( output_size, index_size, @@ -1034,7 +1012,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse( const int* lengths, const float* weights, float* out, - const indxType* compressed_indices_table) { + const int32_t* compressed_indices_table) { return EmbeddingSpMDMNBitRowWiseSparse_ref( bit_rate, block_size, diff --git a/src/RefImplementations.cc b/src/RefImplementations.cc index 73c31b97da..4ccaebe0fd 100644 --- a/src/RefImplementations.cc +++ b/src/RefImplementations.cc @@ -809,7 +809,7 @@ bool EmbeddingSpMDMRowWiseSparse_ref( // const int64_t compressed_data_size, const inType* input, const IndexType* indices, - const IndexType* compressed_indices_table, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -929,7 +929,7 @@ bool EmbeddingSpMDMNBitRowWiseSparse_ref( // const int64_t compressed_data_size, const uint8_t* input, const IndexType* indices, - const IndexType* compressed_indices_table, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1273,14 +1273,14 @@ template FBGEMM_API bool EmbeddingSpMDMNBit_ref( template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( int bit_rate, - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, - const std::uint8_t* input, - const std::int64_t* indices, - const std::int64_t* compressed_indices_table, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, + const uint8_t* input, + const int64_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1288,14 +1288,14 @@ template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, const float16* input, - const std::int64_t* indices, - const std::int64_t* compressed_indices_table, + const int64_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1303,14 +1303,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, const float16* input, - const std::int32_t* indices, - const std::int32_t* compressed_indices_table, + const int32_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1318,14 +1318,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, const float* input, - const std::int64_t* indices, - const std::int64_t* compressed_indices_table, + const int64_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1333,14 +1333,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, const float* input, - const std::int32_t* indices, - const std::int32_t* compressed_indices_table, + const int32_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1348,14 +1348,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, - const std::uint8_t* input, - const std::int64_t* indices, - const std::int64_t* compressed_indices_table, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, + const uint8_t* input, + const int64_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1363,14 +1363,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( bool is_weight_positional); template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, - const std::uint8_t* input, - const std::int32_t* indices, - const std::int32_t* compressed_indices_table, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -1379,14 +1379,14 @@ template FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( template FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( int bit_rate, - const std::int64_t block_size, - const std::int64_t output_size, - const std::int64_t index_size, - const std::int64_t uncompressed_data_size, - // const std::int64_t compressed_data_size, - const std::uint8_t* input, - const std::int32_t* indices, - const std::int32_t* compressed_indices_table, + const int64_t block_size, + const int64_t output_size, + const int64_t index_size, + const int64_t uncompressed_data_size, + // const int64_t compressed_data_size, + const uint8_t* input, + const int32_t* indices, + const int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, diff --git a/src/RefImplementations.h b/src/RefImplementations.h index c82c651b17..ef1bbdee35 100644 --- a/src/RefImplementations.h +++ b/src/RefImplementations.h @@ -252,7 +252,7 @@ FBGEMM_API bool EmbeddingSpMDMRowWiseSparse_ref( // const std::int64_t compressed_data_size, const inType* input, const IndexType* indices, - const IndexType* compressed_indices_table, + const std::int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, @@ -269,7 +269,7 @@ FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref( // const std::int64_t compressed_data_size, const std::uint8_t* input, const IndexType* indices, - const IndexType* compressed_indices_table, + const std::int32_t* compressed_indices_table, const int* lengths, const float* weights, // optional, can be null for non-weighted sum bool normalize_by_lengths, diff --git a/test/EmbeddingSpMDM8BitTest.cc b/test/EmbeddingSpMDM8BitTest.cc index 2a6ada2b39..267dc6d1cd 100644 --- a/test/EmbeddingSpMDM8BitTest.cc +++ b/test/EmbeddingSpMDM8BitTest.cc @@ -227,10 +227,9 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { int average_len = input[3]; // Create mapping table for rowwise sparsity - vector mapping_table; - vector mapping_table_32; - int num_compressed_rows = CreateMappingTableForRowWiseSparsity( - mapping_table, mapping_table_32, num_rows, sparsity); + vector mapping_table; + int num_compressed_rows = + CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity); // Create embedding table default_random_engine generator; @@ -314,7 +313,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { num_rows, fused_embedding_table.data(), corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - mapping_table_32.data(), + mapping_table.data(), lengths.data(), use_weight ? weights.data() : nullptr, normalize_by_lengths, @@ -336,7 +335,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { lengths.data(), use_weight ? weights.data() : nullptr, output.data(), - mapping_table_32.data()); + mapping_table.data()); } // Check correctness diff --git a/test/EmbeddingSpMDMNBitTest.cc b/test/EmbeddingSpMDMNBitTest.cc index c4396c83d6..9b1445580e 100644 --- a/test/EmbeddingSpMDMNBitTest.cc +++ b/test/EmbeddingSpMDMNBitTest.cc @@ -246,10 +246,9 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { int average_len = input[3]; // Create mapping table for rowwise sparsity - vector mapping_table; - vector mapping_table_32; - int num_compressed_rows = CreateMappingTableForRowWiseSparsity( - mapping_table, mapping_table_32, num_rows, sparsity); + vector mapping_table; + int num_compressed_rows = + CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity); // Create embedding table default_random_engine generator; @@ -342,7 +341,7 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { num_rows, fused_embedding_table.data(), corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - mapping_table_32.data(), + mapping_table.data(), lengths.data(), use_weight ? weights.data() : nullptr, normalize_by_lengths, @@ -365,7 +364,7 @@ TEST_P(FusedNBitRowwiseEmbeddingLookupTest, rowwiseSparseTest) { lengths.data(), use_weight ? weights.data() : nullptr, output.data(), - mapping_table_32.data()); + mapping_table.data()); } // Check correctness diff --git a/test/EmbeddingSpMDMTest.cc b/test/EmbeddingSpMDMTest.cc index 20281c00c5..4be1fa9d1e 100644 --- a/test/EmbeddingSpMDMTest.cc +++ b/test/EmbeddingSpMDMTest.cc @@ -292,10 +292,9 @@ TEST_P(EmbeddingSpMDMTest, rowwiseSparseTest) { int average_len = input[3]; // Create mapping table for rowwise sparsity - vector mapping_table; - vector mapping_table_32; - int num_compressed_rows = CreateMappingTableForRowWiseSparsity( - mapping_table, mapping_table_32, num_rows, sparsity); + vector mapping_table; + int num_compressed_rows = + CreateMappingTableForRowWiseSparsity(mapping_table, num_rows, sparsity); // Create embedding table vector embedding_table(num_compressed_rows * embedding_dim); @@ -409,7 +408,7 @@ TEST_P(EmbeddingSpMDMTest, rowwiseSparseTest) { num_rows, embedding_table_fp16.data(), corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - mapping_table_32.data(), + mapping_table.data(), lengths.data(), use_weight ? weights.data() : nullptr, normalize_by_lengths, @@ -431,7 +430,7 @@ TEST_P(EmbeddingSpMDMTest, rowwiseSparseTest) { lengths.data(), use_weight ? weights.data() : nullptr, output.data(), - mapping_table_32.data()); + mapping_table.data()); } else { success_ref = EmbeddingSpMDMRowWiseSparse_ref( embedding_dim, @@ -440,7 +439,7 @@ TEST_P(EmbeddingSpMDMTest, rowwiseSparseTest) { num_rows, embedding_table.data(), corner_case == EMPTY_INDICES ? nullptr : indices_32.data(), - mapping_table_32.data(), + mapping_table.data(), lengths.data(), use_weight ? weights.data() : nullptr, normalize_by_lengths, @@ -462,7 +461,7 @@ TEST_P(EmbeddingSpMDMTest, rowwiseSparseTest) { lengths.data(), use_weight ? weights.data() : nullptr, output.data(), - mapping_table_32.data()); + mapping_table.data()); } } diff --git a/test/EmbeddingSpMDMTestUtils.cc b/test/EmbeddingSpMDMTestUtils.cc index 5e6728da30..3b01eb707a 100644 --- a/test/EmbeddingSpMDMTestUtils.cc +++ b/test/EmbeddingSpMDMTestUtils.cc @@ -65,8 +65,7 @@ int GenerateLengthsIndicesWeights( } int CreateMappingTableForRowWiseSparsity( - vector& mapping_table, - vector& mapping_table_32, + vector& mapping_table, int num_rows, float sparsity) { default_random_engine generator; @@ -82,10 +81,6 @@ int CreateMappingTableForRowWiseSparsity( ++num_compressed_rows; } } - copy( - mapping_table.begin(), - mapping_table.end(), - back_inserter(mapping_table_32)); return num_compressed_rows; } diff --git a/test/EmbeddingSpMDMTestUtils.h b/test/EmbeddingSpMDMTestUtils.h index 036db655be..a9572901e9 100644 --- a/test/EmbeddingSpMDMTestUtils.h +++ b/test/EmbeddingSpMDMTestUtils.h @@ -30,8 +30,7 @@ int GenerateLengthsIndicesWeights( * @return num_compressed_rows */ int CreateMappingTableForRowWiseSparsity( - std::vector& mapping_table, - std::vector& mapping_table_32, + std::vector& mapping_table, int num_rows, float sparsity);