Adding FP 32 SLS, and unifying it with 8 Bit SLS (pytorch#206)
Summary:
Pull Request resolved: pytorch#206

This adds support for 32-bit indices to the JITed FP32 SLS op.

This diff includes the following SLS features (a usage sketch follows the list):

1. Normalizes by lengths.
2. Uses modified prefetch distances for avx2 vs. avx512.
3. Adds support for 32-bit indices.
4. Supports weighted SLS, including positional weights.
5. Does not specialize for block size 1 on avx512, as this would reorder the reduction.
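
For orientation, here is a minimal sketch of driving the unified entry point for FP32 SLS. The argument order mirrors the benchmark calls in the diff below; the float input type is inferred from the commit summary rather than shown in the visible hunks, and the header name and toy data are assumptions, not part of this diff:

// Hedged sketch: FP32 SLS through the unified fbgemm::EmbeddingSpMDM
// entry point. Argument order mirrors the benchmark calls in this
// diff; the <float, int32_t> instantiation is inferred from the
// commit summary, and the header name is an assumption.
#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmEmbedding.h" // header name assumed

int main() {
  const int embedding_dim = 128;
  const int batch_size = 10;
  const int num_rows = 4000;

  // Toy inputs: each batch element pools two rows of an all-ones table.
  std::vector<float> table(num_rows * embedding_dim, 1.0f);
  std::vector<std::int32_t> indices(2 * batch_size);
  std::vector<int> lengths(batch_size, 2);
  for (int i = 0; i < batch_size; ++i) {
    indices[2 * i] = i;
    indices[2 * i + 1] = i + 1;
  }
  std::vector<float> out(batch_size * embedding_dim);

  // 32-bit indices pick the int32_t instantiation, as in the benchmark
  // hunks below. With normalize_by_lengths = true, each output row
  // should come out as all ones for this toy table.
  fbgemm::EmbeddingSpMDM<float, std::int32_t>(
      embedding_dim,
      batch_size,
      /*index_size=*/2 * batch_size,
      /*data_size=*/num_rows,
      table.data(),
      indices.data(),
      lengths.data(),
      /*weights=*/nullptr,
      /*normalize_by_lengths=*/true,
      out.data());
  return 0;
}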

Reviewed By: jspark1105

Differential Revision: D18210640

fbshipit-source-id: f9b4de5707a59cae5d34cb898c0cf52bc5f2a91f
Protonu Basu authored and jspark1105 committed Mar 21, 2020
1 parent 10eff7b commit 81cc0a1
Showing 10 changed files with 1,006 additions and 199 deletions.
3 changes: 2 additions & 1 deletion CMakeLists.txt
@@ -25,7 +25,8 @@ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
 
 #All the source files that either use avx2 instructions statically or JIT
 #avx2/avx512 instructions.
-set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
+set(FBGEMM_GENERIC_SRCS src/EmbeddingSpMDM.cc
+                        src/ExecuteKernel.cc
                         src/ExecuteKernelU8S8.cc
                         src/Fbgemm.cc
                         src/FbgemmBfloat16Convert.cc
@@ -72,10 +72,10 @@ static vector<vector<int>> GetInputs_() {
       // {10, 4000000, 128, 100},
       // {10, 4000000, 256, 100},
       // Use these for debugging
-      {2, 16, 128, 10},
-      {10, 4000, 128, 100},
-      {10, 4000, 128, 100},
-      {10, 4000, 128, 100},
+      {2, 16, 128, 10},
+      {10, 4000, 128, 100},
+      {10, 4000, 128, 100},
+      {10, 4000, 128, 100},
   };
   return input_dims;
 }
@@ -167,32 +167,30 @@ int run_benchmark(
 
   for (int i = 0; i < NUM_WARMUP + NUM_ITER; ++i) {
     if (use_32_bit_indices) {
-      fbgemm::
-          Fused8BitRowwiseEmbeddingLookup_ref<int32_t, uint8_t, float, false>(
-              embedding_dim,
-              batch_size,
-              lengths_sum,
-              num_unique_ids,
-              fused_embedding_table,
-              indices_32.data(),
-              lengths.data(),
-              has_weight ? weights.data() : nullptr,
-              normalize_by_lengths,
-              output_ref.data());
+      fbgemm::EmbeddingSpMDM_ref(
+          embedding_dim,
+          batch_size,
+          lengths_sum,
+          num_unique_ids,
+          fused_embedding_table,
+          indices_32.data(),
+          lengths.data(),
+          has_weight ? weights.data() : nullptr,
+          normalize_by_lengths,
+          output_ref.data());
 
     } else {
-      fbgemm::
-          Fused8BitRowwiseEmbeddingLookup_ref<int64_t, uint8_t, float, false>(
-              embedding_dim,
-              batch_size,
-              lengths_sum,
-              num_unique_ids,
-              fused_embedding_table,
-              indices.data(),
-              lengths.data(),
-              has_weight ? weights.data() : nullptr,
-              normalize_by_lengths,
-              output_ref.data());
+      fbgemm::EmbeddingSpMDM_ref(
+          embedding_dim,
+          batch_size,
+          lengths_sum,
+          num_unique_ids,
+          fused_embedding_table,
+          indices.data(),
+          lengths.data(),
+          has_weight ? weights.data() : nullptr,
+          normalize_by_lengths,
+          output_ref.data());
     }
   }

@@ -201,7 +199,6 @@ int run_benchmark(
   t = 0;
   for (int i = 0; i < NUM_WARMUP + NUM_ITER; ++i) {
     if (flush_cache) {
-      llc_flush(embedding_table);
       llc_flush_fused_table(
           fused_embedding_table, num_unique_ids * (embedding_dim + 8));
       llc_flush(indices);
@@ -214,7 +211,7 @@ int run_benchmark(
     if (use_32_bit_indices) {
       t_begin = chrono::system_clock::now();
 
-      fbgemm::Fused8BitRowwiseEmbeddingLookup<int32_t>(
+      fbgemm::EmbeddingSpMDM<uint8_t, int32_t>(
           embedding_dim,
           batch_size,
           lengths_sum,
@@ -232,7 +229,7 @@ int run_benchmark(
     } else {
       t_begin = chrono::system_clock::now();
 
-      fbgemm::Fused8BitRowwiseEmbeddingLookup<int64_t>(
+      fbgemm::EmbeddingSpMDM<uint8_t, int64_t>(
          embedding_dim,
          batch_size,
          lengths_sum,
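Since the diff pairs each JIT call with an EmbeddingSpMDM_ref call writing into output_ref, the same pattern doubles as a correctness check. A minimal sketch, assuming the reference function shares the argument order shown above and deduces its template parameters from the pointer types; the header name and tolerance are illustrative assumptions:

// Sketch: validate the JIT kernel against the reference path, mirroring
// the benchmark's use of output_ref. Assumes both functions take the
// argument order shown in the diff above.
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmEmbedding.h" // header name assumed

bool check_sls(
    int embedding_dim,
    int batch_size,
    int index_count,
    int num_rows,
    const std::uint8_t* fused_table, // 8-bit rows with fused scale/bias
    const std::int32_t* indices,
    const int* lengths,
    bool normalize_by_lengths) {
  std::vector<float> out(batch_size * embedding_dim);
  std::vector<float> out_ref(batch_size * embedding_dim);

  // JIT path, as in the timed loop of the benchmark.
  fbgemm::EmbeddingSpMDM<std::uint8_t, std::int32_t>(
      embedding_dim, batch_size, index_count, num_rows, fused_table,
      indices, lengths, /*weights=*/nullptr, normalize_by_lengths,
      out.data());

  // Reference path, as in the warm-up loop; template parameters are
  // deduced from the pointer arguments, matching the diff's usage.
  fbgemm::EmbeddingSpMDM_ref(
      embedding_dim, batch_size, index_count, num_rows, fused_table,
      indices, lengths, /*weights=*/nullptr, normalize_by_lengths,
      out_ref.data());

  for (std::size_t i = 0; i < out.size(); ++i) {
    if (std::fabs(out[i] - out_ref[i]) > 1e-3f) { // illustrative tolerance
      return false;
    }
  }
  return true;
}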
