Skip to content

Commit

Permalink
cleaning up some code in referene implementations WRT SLS that is not…
Browse files Browse the repository at this point in the history
… needed

Summary: Removing some code in Ref implementation WRT SLS that was not needed.

Reviewed By: dskhudia

Differential Revision: D19147339

fbshipit-source-id: 6463b8e6319619f7eabe06e9817c07230d70310e
  • Loading branch information
Protonu Basu authored and jspark1105 committed Mar 21, 2020
1 parent 76e8f19 commit 09af4c9
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 129 deletions.
86 changes: 0 additions & 86 deletions src/RefImplementations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -632,92 +632,6 @@ void transposeConvWeights(
}
}

template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL>
bool Fused8BitRowwiseEmbeddingLookup_ref(
const int64_t block_size,
const int64_t output_size,
const int64_t index_size,
const int64_t data_size,
const InType* input,
const IndexType* indices,
const int* lengths,
const float* weights, // optional, can be null for sum reducer
bool normalize_by_lengths,
OutType* out) {
// block_size is the number of elements and fused_block_size is the size of
// an entire row, including scale and bias.
const auto scale_bias_offset = 8 / sizeof(InType);
const int64_t fused_block_size = block_size + scale_bias_offset;
int64_t current = 0;
for (int m = 0; m < output_size; ++m) {
memset(out, 0, sizeof(OutType) * block_size);
if (current + lengths[m] > index_size) {
return false;
}
for (int i = 0; i < lengths[m]; ++i) {
int64_t idx = indices[current];
if (idx < 0 || idx >= data_size) {
return false;
}

const float* scale_bias = reinterpret_cast<const float*>(
input + fused_block_size * indices[current] + block_size);

float weight = 1.0f;
if (weights) {
weight = weights[IS_WEIGHT_POSITIONAL ? i : current];
}
const float scale = weight * scale_bias[0];
const float bias = weight * scale_bias[1];

for (int j = 0; j < block_size; ++j) {
out[j] = std::fma(
scale,
input[fused_block_size * indices[current] + j],
out[j] + bias);
}

++current;
}
if (normalize_by_lengths && lengths[m]) {
float scale = 1.f / lengths[m];
for (int j = 0; j < block_size; ++j) {
out[j] *= scale;
}
}
out += block_size;
}
return true;
}

template bool Fused8BitRowwiseEmbeddingLookup_ref(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size,
const uint8_t* input,
const std::int64_t* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out);

template bool Fused8BitRowwiseEmbeddingLookup_ref(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size,
const uint8_t* input,
const std::int32_t* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out);

template <typename inType, typename IndexType>
bool EmbeddingSpMDM_ref(
const std::int64_t block_size,
Expand Down
17 changes: 0 additions & 17 deletions src/RefImplementations.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,23 +214,6 @@ FBGEMM_API void im2col_ref(
std::int32_t A_zero_point,
std::uint8_t* Ao);

template <
typename IndexType,
typename InType,
typename OutType,
bool IS_WEIGHT_POSITIONAL = false>
FBGEMM_API bool Fused8BitRowwiseEmbeddingLookup_ref(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size,
const InType* input,
const IndexType* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
OutType* out);

template <typename inType = std::uint8_t, typename IndexType = std::int64_t>
FBGEMM_API bool EmbeddingSpMDM_ref(
const std::int64_t block_size,
Expand Down
49 changes: 23 additions & 26 deletions test/EmbeddingSpMDM8BitTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
* LICENSE file in the root directory of this source tree.
*/
#include <algorithm>
#include <numeric>
#include <ostream>
#include <random>
#include <stdexcept>
#include <numeric>

#include <gtest/gtest.h>

Expand Down Expand Up @@ -130,7 +130,6 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {

// Compute the number of indices
int lengths_sum = accumulate(lengths.begin(), lengths.end(), 0);
//cout << "lenths sum " << lengths_sum;

// Generate indices
vector<int64_t> indices;
Expand All @@ -156,18 +155,17 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
vector<float>& output_ref = use_weight ? output_slws_ref : output_sls_ref;
vector<float>& output = use_weight ? output_slws : output_sls;
if (isIndex64b) {
fbgemm::
Fused8BitRowwiseEmbeddingLookup_ref<int64_t, uint8_t, float, false>(
embedding_dim,
batch_size,
lengths_sum,
num_unique_ids,
fused_embedding_table,
empty_indices ? nullptr : indices.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
fbgemm::EmbeddingSpMDM_ref<uint8_t, int64_t>(
embedding_dim,
batch_size,
lengths_sum,
num_unique_ids,
fused_embedding_table,
empty_indices ? nullptr : indices.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());

fbgemm::EmbeddingSpMDM<uint8_t, int64_t>(
embedding_dim,
Expand All @@ -183,18 +181,17 @@ TEST_P(Fused8BitRowwiseEmbeddingLookupTest, basicTest) {
prefetch ? 16 : 0);

} else {
fbgemm::
Fused8BitRowwiseEmbeddingLookup_ref<int32_t, uint8_t, float, false>(
embedding_dim,
batch_size,
lengths_sum,
num_unique_ids,
fused_embedding_table,
empty_indices ? nullptr : indices_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());
fbgemm::EmbeddingSpMDM_ref<uint8_t, int32_t>(
embedding_dim,
batch_size,
lengths_sum,
num_unique_ids,
fused_embedding_table,
empty_indices ? nullptr : indices_32.data(),
lengths.data(),
use_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());

fbgemm::EmbeddingSpMDM<uint8_t, int32_t>(
embedding_dim,
Expand Down

0 comments on commit 09af4c9

Please sign in to comment.