Skip to content

Commit

Permalink
clean up embedding spmdm tests (pytorch#300)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#300

Refactor redundant code among various embedding spmdm tests. More careful testing of corner cases

Reviewed By: jianyuh

Differential Revision: D19943506

fbshipit-source-id: 4eb8d55db7db729db65017c18c8ed0695c54516c
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Feb 19, 2020
1 parent 2b6eef4 commit f40e6d1
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 396 deletions.
2 changes: 1 addition & 1 deletion src/RefImplementations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -740,7 +740,7 @@ bool EmbeddingSpMDM_ref(
}
out += block_size;
}
return true;
return current == index_size;
} else {
// Reference implementation of FP32 SLS
int64_t current = 0;
Expand Down
5 changes: 4 additions & 1 deletion test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ endif()

macro(add_gtest TESTNAME)
add_executable(${TESTNAME} ${ARGN}
../bench/BenchUtils.cc QuantizationHelpers.cc TestUtils.cc)
../bench/BenchUtils.cc
EmbeddingSpMDMTestUtils.cc
QuantizationHelpers.cc
TestUtils.cc)
set_target_properties(${TESTNAME} PROPERTIES
CXX_STANDARD 11
CXX_EXTENSIONS NO)
Expand Down
176 changes: 66 additions & 110 deletions test/EmbeddingSpMDM8BitTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <gtest/gtest.h>

#include "./EmbeddingSpMDMTestUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
#include "src/RefImplementations.h"
Expand Down Expand Up @@ -48,10 +49,9 @@ vector<int> prefetch_distances{0, 16, 1000000};

namespace {

// tuple represents MB, IC, OC, IT, IH, IW, KH/KW, stride, pad
class Fused8BitRowwiseEmbeddingLookupTest
: public testing::TestWithParam<
tuple<bool, bool, int, bool, bool, bool, bool>> {};
tuple<bool, bool, int, bool, bool, EmbeddingSpMDMCornerCase>> {};
}; // namespace

INSTANTIATE_TEST_CASE_P(
Expand All @@ -63,21 +63,23 @@ INSTANTIATE_TEST_CASE_P(
::testing::ValuesIn(prefetch_distances),
::testing::Bool(), // use_weight
::testing::Bool(), // normalize_by_lengths
::testing::Bool(), // empty_indices
::testing::Bool())); // out of bounds
::testing::Values(
NONE,
EMPTY_INDICES,
OUT_OF_BOUND_INDICES,
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM)));

TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
vector<vector<int>> inputs(GetInputs_());
bool isIndex64b, is_wt_positional, use_weight, normalize_by_lengths,
empty_indices, out_of_bounds;
bool isIndex64b, is_wt_positional, use_weight, normalize_by_lengths;
int prefetch;
EmbeddingSpMDMCornerCase corner_case;
tie(isIndex64b,
is_wt_positional,
prefetch,
use_weight,
normalize_by_lengths,
empty_indices,
out_of_bounds) = GetParam();
corner_case) = GetParam();

for (auto input : inputs) {
int batch_size = input[0];
Expand All @@ -104,39 +106,20 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
scale_bias[1] = embedding_distribution(generator);
}

// Generate lengths
uniform_int_distribution<int> length_distribution(
1, std::min(2 * average_len + 1, num_rows));
vector<int> lengths(batch_size);
for (int i = 0; i < batch_size; ++i) {
lengths[i] = empty_indices ? 0 : length_distribution(generator);
}

// Compute the number of indices
int lengths_sum = accumulate(lengths.begin(), lengths.end(), 0);

// Generate indices
vector<int64_t> indices(lengths_sum);
vector<int32_t> indices_32(lengths_sum);

uniform_int_distribution<int> index_distribution(0, num_rows - 1);
for (int i = 0; i < lengths_sum; ++i) {
indices_32[i] = indices[i] = index_distribution(generator);
}
if (!empty_indices && out_of_bounds) {
int idx = uniform_int_distribution<int>(0, lengths_sum - 1)(generator);
indices_32[idx] = indices[idx] = num_rows;
}
if (!empty_indices) {
// To make sure to exercise out-of-bound cases
indices_32[0] = indices[0] = num_rows - 1;
}

// Generate weights
vector<float> weights(lengths_sum);
for (int i = 0; i < lengths_sum; ++i) {
weights[i] = embedding_distribution(generator);
}
vector<int> lengths;
vector<int64_t> indices;
vector<int32_t> indices_32;
vector<float> weights;
int lengths_sum = GenerateLengthsIndicesWeights(
lengths,
indices,
indices_32,
weights,
batch_size,
num_rows,
embedding_dim,
average_len,
corner_case);

vector<float> output_sls_ref(batch_size * embedding_dim);
vector<float> output_slws_ref(output_sls_ref.size()),
Expand All @@ -153,7 +136,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices.data(),
corner_case == EMPTY_INDICES ? nullptr : indices.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
Expand All @@ -171,7 +154,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices.data(),
corner_case == EMPTY_INDICES ? nullptr : indices.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
output.data());
Expand All @@ -182,7 +165,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices_32.data(),
corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
Expand All @@ -200,7 +183,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices_32.data(),
corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
output.data());
Expand All @@ -209,6 +192,10 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
// Check correctness
EXPECT_EQ(success, success_ref)
<< "Reference and JIT impl did not both succeed";
if (corner_case == OUT_OF_BOUND_INDICES ||
corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
EXPECT_EQ(success, false);
}
if (success) {
for (int i = 0; i < output.size(); ++i) {
EXPECT_EQ(output[i], output_ref[i])
Expand All @@ -221,16 +208,15 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {

TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
vector<vector<int>> inputs(GetInputs_());
bool isIndex64b, is_wt_positional, use_weight, normalize_by_lengths,
empty_indices, out_of_bounds;
bool isIndex64b, is_wt_positional, use_weight, normalize_by_lengths;
int prefetch;
EmbeddingSpMDMCornerCase corner_case;
tie(isIndex64b,
is_wt_positional,
prefetch,
use_weight,
normalize_by_lengths,
empty_indices,
out_of_bounds) = GetParam();
corner_case) = GetParam();

constexpr float sparsity = 0.7;

Expand All @@ -240,33 +226,21 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
int embedding_dim = input[2];
int average_len = input[3];

// Create mapping table for rowwise sparsity
vector<int64_t> mapping_table;
vector<int32_t> mapping_table_32;
int num_compressed_rows = CreateMappingTableForRowWiseSparsity(
mapping_table, mapping_table_32, num_rows, sparsity);

// Create embedding table
default_random_engine generator;
normal_distribution<float> embedding_distribution;
uniform_int_distribution<int> entries(0, 16);

// Create mapping table for rowwise sparsity
vector<int64_t> mapping_table(num_rows);
bernoulli_distribution row_prune_dist(sparsity);
int num_compressed_rows = 0;
for (int i = 0; i < num_rows; ++i) {
if (row_prune_dist(generator)) {
// pruned
mapping_table[i] = -1;
} else {
mapping_table[i] = num_compressed_rows;
++num_compressed_rows;
}
}
vector<int32_t> mapping_table_32;
copy(
mapping_table.begin(),
mapping_table.end(),
back_inserter(mapping_table_32));

int fused_embedding_dim = embedding_dim + 2 * sizeof(float);
vector<uint8_t> fused_embedding_table(num_rows * fused_embedding_dim);
for (int i = 0; i < num_rows; i++) {
vector<uint8_t> fused_embedding_table(
num_compressed_rows * fused_embedding_dim);
for (int i = 0; i < num_compressed_rows; i++) {
for (int ii = 0; ii < embedding_dim; ii++) {
fused_embedding_table[i * fused_embedding_dim + ii] =
entries(generator);
Expand All @@ -278,42 +252,20 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
scale_bias[1] = embedding_distribution(generator);
}

// Generate lengths
uniform_int_distribution<int> length_distribution(
1, std::min(2 * average_len + 1, num_rows));
vector<int> lengths(batch_size);
for (int i = 0; i < batch_size; ++i) {
lengths[i] = empty_indices ? 0 : length_distribution(generator);
}

// Compute the number of indices
int lengths_sum = accumulate(lengths.begin(), lengths.end(), 0);

// Generate indices
vector<int64_t> indices(lengths_sum);
vector<int32_t> indices_32(lengths_sum);

uniform_int_distribution<int> index_distribution(0, num_rows - 1);
for (int i = 0; i < lengths_sum; ++i) {
indices_32[i] = indices[i] = index_distribution(generator);
}
if (!empty_indices && out_of_bounds) {
int idx = uniform_int_distribution<int>(0, lengths_sum - 1)(generator);
indices_32[idx] = indices[idx] = num_rows;

// idx = uniform_int_distribution<int>(0, num_rows - 1)(generator);
// mapping_table_32[idx] = mapping_table[idx] = num_compressed_rows;
}
if (!empty_indices) {
// To make sure to exercise out-of-bound cases
indices_32[0] = indices[0] = num_rows - 1;
}

// Generate weights
vector<float> weights(lengths_sum);
for (int i = 0; i < lengths_sum; ++i) {
weights[i] = embedding_distribution(generator);
}
vector<int> lengths;
vector<int64_t> indices;
vector<int32_t> indices_32;
vector<float> weights;
int lengths_sum = GenerateLengthsIndicesWeights(
lengths,
indices,
indices_32,
weights,
batch_size,
num_rows,
embedding_dim,
average_len,
corner_case);

vector<float> output_sls_ref(batch_size * embedding_dim);
vector<float> output_slws_ref(output_sls_ref.size()),
Expand All @@ -330,7 +282,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices.data(),
corner_case == EMPTY_INDICES ? nullptr : indices.data(),
mapping_table.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
Expand All @@ -349,7 +301,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices.data(),
corner_case == EMPTY_INDICES ? nullptr : indices.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
output.data(),
Expand All @@ -361,7 +313,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices_32.data(),
corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
mapping_table_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
Expand All @@ -380,7 +332,7 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
lengths_sum,
num_rows,
fused_embedding_table.data(),
empty_indices ? nullptr : indices_32.data(),
corner_case == EMPTY_INDICES ? nullptr : indices_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
output.data(),
Expand All @@ -390,6 +342,10 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, rowwiseSparseTest) {
// Check correctness
EXPECT_EQ(success, success_ref)
<< "Reference and JIT impl did not both succeed";
if (corner_case == OUT_OF_BOUND_INDICES ||
corner_case == UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM) {
EXPECT_EQ(success, false);
}
if (success) {
for (int i = 0; i < output.size(); ++i) {
EXPECT_EQ(output[i], output_ref[i])
Expand Down
Loading

0 comments on commit f40e6d1

Please sign in to comment.