Skip to content

Commit

Permalink
2bit JIT'ed SLS kernel (pytorch#234)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#234

As title

Reviewed By: dskhudia

Differential Revision: D19258370

fbshipit-source-id: 46f86d8b69ef73302fc2f71e8d379e32eef692d8
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Jan 31, 2020
1 parent 969d173 commit 0e31e0a
Show file tree
Hide file tree
Showing 7 changed files with 303 additions and 258 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
#All the source files that either use avx2 instructions statically or JIT
#avx2/avx512 instructions.
set(FBGEMM_GENERIC_SRCS src/EmbeddingSpMDM.cc
src/EmbeddingSpMDM4Bit.cc
src/EmbeddingSpMDMNBit.cc
src/ExecuteKernel.cc
src/ExecuteKernelU8S8.cc
src/Fbgemm.cc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ static vector<vector<int>> GetInputs_() {
}

int run_benchmark(
int bit_rate,
int batch_size,
int num_unique_ids,
int embedding_dim,
Expand All @@ -75,19 +76,24 @@ int run_benchmark(
bool use_32_bit_indices = false,
bool prefetch = false) {
// Create embedding table
int fused_embedding_dim = (embedding_dim + 1) / 2 + 2 * sizeof(float16);
int num_elem_per_byte = 8 / bit_rate;
int fused_embedding_dim =
(embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte +
2 * sizeof(float16);
vector<uint8_t> embedding_table(num_unique_ids * fused_embedding_dim);
default_random_engine generator;
normal_distribution<float> embedding_distribution;

vector<uint8_t> fused_embedding_table(num_unique_ids * fused_embedding_dim);
for (int i = 0; i < num_unique_ids; i++) {
for (int ii = 0; ii < embedding_dim / 2; ii++) {
for (int ii = 0;
ii < (embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte;
ii++) {
fused_embedding_table[i * fused_embedding_dim + ii] = 2;
}
float16* scale_bias = reinterpret_cast<float16*>(
&fused_embedding_table[i * fused_embedding_dim] +
(embedding_dim + 1) / 2);
(embedding_dim + num_elem_per_byte - 1) / num_elem_per_byte);
float scale = 2.0f;
float bias = 1.0f;
FloatToFloat16_ref(&scale, scale_bias, 1, true /* clip */);
Expand Down Expand Up @@ -140,20 +146,20 @@ int run_benchmark(
constexpr int NUM_ITER = 10;
// Only counts the number of bytes for reading embedding table and ignore
// others. Should be good enough as long as embdding_dim is big enough.
double bytes =
lengths_sum * (embedding_dim * sizeof(uint8_t) + 2 * sizeof(float));
double bytes_padded =
lengths_sum * 64 *
static_cast<int>(
(embedding_dim * sizeof(uint8_t) + 2 * sizeof(float) + 63) / 64);
double bytes = lengths_sum * fused_embedding_dim;
constexpr int CACHE_LINE_LEN = 64;
double bytes_padded = lengths_sum * CACHE_LINE_LEN *
static_cast<int>((fused_embedding_dim + CACHE_LINE_LEN - 1) /
CACHE_LINE_LEN);

for (bool has_weight : {false, true}) {
vector<float>& output_ref = has_weight ? output_slws_ref : output_sls_ref;

bool success = false, success_ref = false;

if (use_32_bit_indices) {
success_ref = EmbeddingSpMDM4Bit_ref(
success_ref = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
Expand All @@ -164,9 +170,9 @@ int run_benchmark(
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
output_ref.data());

} else {
success_ref = EmbeddingSpMDM4Bit_ref(
success = EmbeddingSpMDMNBit_ref(
bit_rate,
embedding_dim,
batch_size,
lengths_sum,
Expand All @@ -179,10 +185,10 @@ int run_benchmark(
output_ref.data());
}

auto kernel_32 = GenerateEmbeddingSpMDM4Bit<int32_t>(
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
auto kernel_64 = GenerateEmbeddingSpMDM4Bit<int64_t>(
embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
auto kernel_32 = GenerateEmbeddingSpMDMNBit<int32_t>(
bit_rate, embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);
auto kernel_64 = GenerateEmbeddingSpMDMNBit<int64_t>(
bit_rate, embedding_dim, has_weight, normalize_by_lengths, prefetch ? 16 : 0);

vector<float>& output = has_weight ? output_slws : output_sls;
for (bool flush_cache : {false, true}) {
Expand Down Expand Up @@ -275,58 +281,75 @@ int main() {

vector<vector<int>> inputs(GetInputs_());

for (auto& input : inputs) {
assert(input.size() > 3);
batch_size = input[0];
num_unique_ids = input[1];
embedding_dim = input[2];
average_len = input[3];

cout << "batch size" << setw(6) << batch_size << setw(10) << "num rows"
<< setw(16) << num_unique_ids << setw(10) << "emb dim" << setw(6)
<< embedding_dim << setw(16) << "avg length" << setw(6) << average_len
<< endl;
// args: batch sz, num rows, emb dim, avg len, normalize, use 32b, prefetch
cout << "64 bit indices, ";
run_benchmark(
batch_size, num_unique_ids, embedding_dim, average_len, false);

cout << "64 bit indices with prefetching, ";
run_benchmark(
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false,
false,
true);

cout << "32 bit indices, ";
run_benchmark(
batch_size, num_unique_ids, embedding_dim, average_len, false, true);

cout << "32 bit indices with prefetching, ";
run_benchmark(
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false,
true,
true);

// running with normalize by lengths
// run_benchmark(batch_size, num_unique_ids, embedding_dim, average_len,
// true); run_benchmark(
// batch_size, num_unique_ids, embedding_dim, average_len, true, true);
// run_benchmark(
// batch_size,
// num_unique_ids,
// embedding_dim,
// average_len,
// false,
// true,
// true);
for (int bit_rate : {2, 4}) {
for (auto& input : inputs) {
assert(input.size() > 3);
batch_size = input[0];
num_unique_ids = input[1];
embedding_dim = input[2];
average_len = input[3];

cout << "bit_rate" << setw(6) << bit_rate << "batch size" << setw(6)
<< batch_size << setw(10) << "num rows" << setw(16) << num_unique_ids
<< setw(10) << "emb dim" << setw(6) << embedding_dim << setw(16)
<< "avg length" << setw(6) << average_len << endl;
// args: batch sz, num rows, emb dim, avg len, normalize, use 32b,
// prefetch
cout << "64 bit indices, ";
run_benchmark(
bit_rate,
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false); // normalize_by_lengths

cout << "64 bit indices with prefetching, ";
run_benchmark(
bit_rate,
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false, // normalize_by_lengths
false, // use_32_bit_indices
true); // prefetch

cout << "32 bit indices, ";
run_benchmark(
bit_rate,
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false, // normalize_by_lengths
true); // use_32_bit_indices

cout << "32 bit indices with prefetching, ";
run_benchmark(
bit_rate,
batch_size,
num_unique_ids,
embedding_dim,
average_len,
false, // normalize_by_lengths
true, // use_32_bit_indices
true); // prefetch

// running with normalize by lengths
// run_benchmark(batch_size, num_unique_ids, embedding_dim, average_len,
// true); run_benchmark(
// batch_size, num_unique_ids, embedding_dim, average_len, true,
// true);
// run_benchmark(
// batch_size,
// num_unique_ids,
// embedding_dim,
// average_len,
// false,
// true,
// true);
}
}
return 0;
}
18 changes: 2 additions & 16 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,28 +52,14 @@ FBGEMM_API bool EmbeddingSpMDM(

template <typename IndexType>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<std::uint8_t, IndexType>::Type
GenerateEmbeddingSpMDM4Bit(
GenerateEmbeddingSpMDMNBit(
int bit_rate,
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false);

template <typename IndexType = std::int64_t>
FBGEMM_API bool EmbeddingSpMDM4Bit(
const std::int64_t block_size,
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size, // the number of rows in input
const std::uint8_t* input,
const IndexType* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out,
int prefetch = 16,
bool is_weight_positional = false);

/**
* @return The number of rows processed. If smaller than num_rows, an error
* must have happened at the last row processed.
Expand Down
Loading

0 comments on commit 0e31e0a

Please sign in to comment.