fix index mapping table type to int32_t (pytorch#302)
Summary:
Pull Request resolved: pytorch#302

This addresses the difficulty of detecting the index type during model transformation: the number of rows in an embedding table should not exceed the int32_t range, especially after row-wise pruning, so the compressed index mapping table can always be int32_t.

Reviewed By: dehuacheng, jianyuh

Differential Revision: D19954308

fbshipit-source-id: 8d152f747f7e95baa936bd66494414b516906b09
jspark1105 committed Mar 21, 2020
1 parent 6a4ed1c commit 1a7be35
Showing 11 changed files with 135 additions and 192 deletions.
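
The gist of the change: a row-wise pruning mapping table assigns each original row either a compressed row id or the sentinel -1, and such ids comfortably fit in int32_t. A minimal standalone sketch of building one (illustrative function name; it mirrors the benchmark code below and is not code from this commit):

    #include <cstdint>
    #include <random>
    #include <vector>

    // Build a row-wise pruning map: pruned rows -> -1, kept rows -> the next
    // compressed row id. int32_t is enough because the id never exceeds the
    // (already bounded) number of rows.
    std::vector<std::int32_t> make_mapping_table(int num_rows, float sparsity) {
      std::default_random_engine generator;
      std::bernoulli_distribution row_prune_dist(sparsity);
      std::vector<std::int32_t> mapping_table(num_rows);
      std::int32_t num_compressed_rows = 0;
      for (int i = 0; i < num_rows; ++i) {
        mapping_table[i] = row_prune_dist(generator) ? -1 : num_compressed_rows++;
      }
      return mapping_table;
    }
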
11 changes: 3 additions & 8 deletions bench/EmbeddingSpMDMNBitRowWiseSparseBenchmark.cc
@@ -68,7 +68,7 @@ int run_benchmark(
// Generate mapping table
default_random_engine generator;
constexpr float sparsity = 0.7;
- vector<int64_t> mapping_table(num_rows);
+ vector<int32_t> mapping_table(num_rows);
bernoulli_distribution row_prune_dist(sparsity);
int num_compressed_rows = 0;
for (int i = 0; i < num_rows; ++i) {
@@ -80,11 +80,6 @@ int run_benchmark(
++num_compressed_rows;
}
}
- vector<int32_t> mapping_table_32;
- copy(
- mapping_table.begin(),
- mapping_table.end(),
- back_inserter(mapping_table_32));

// Create embedding table
int num_elem_per_byte = 8 / bit_rate;
@@ -192,7 +187,7 @@ int run_benchmark(
num_rows,
fused_embedding_table.data(),
indices_32.data(),
- mapping_table_32.data(),
+ mapping_table.data(),
lengths.data(),
has_weight ? weights.data() : nullptr,
normalize_by_lengths,
@@ -243,7 +238,7 @@ int run_benchmark(
lengths.data(),
has_weight ? weights.data() : nullptr,
output.data(),
- mapping_table_32.data());
+ mapping_table.data());
} else {
success = kernel_64(
batch_size,
2 changes: 1 addition & 1 deletion include/fbgemm/FbgemmEmbedding.h
@@ -66,7 +66,7 @@ class EmbeddingSpMDMRowWiseSparseKernelSignature {
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
float* out,
- const IndexType* compressed_indices_table)>;
+ const std::int32_t* compressed_indices_table)>;
};

/**
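
To make the new signature concrete, here is a rough reference-style sketch of what a row-wise sparse lookup with an int32_t compressed_indices_table computes (simplified to float input, 64-bit indices, no weights, no normalization, no bounds checking; the sentinel-skip behavior matches the `cmp ..., -1` in the JIT code below, but this is not FBGEMM's actual implementation):

    #include <cstdint>
    #include <cstring>

    bool embedding_spmdm_rowwise_sparse_sketch(
        std::int64_t block_size,
        std::int64_t output_size,
        const std::int64_t* indices, // uncompressed row ids
        const int* lengths, // bag sizes, one per output row
        const float* compressed_table, // [num_compressed_rows][block_size]
        const std::int32_t* compressed_indices_table, // -1 marks pruned rows
        float* out) {
      std::int64_t current = 0;
      for (std::int64_t i = 0; i < output_size; ++i) {
        float* dst = out + i * block_size;
        std::memset(dst, 0, block_size * sizeof(float));
        for (int j = 0; j < lengths[i]; ++j, ++current) {
          std::int32_t row = compressed_indices_table[indices[current]];
          if (row == -1) {
            continue; // pruned row contributes nothing to the sum
          }
          const float* src =
              compressed_table + static_cast<std::int64_t>(row) * block_size;
          for (std::int64_t k = 0; k < block_size; ++k) {
            dst[k] += src[k];
          }
        }
      }
      return true;
    }
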
81 changes: 30 additions & 51 deletions src/EmbeddingSpMDM.cc
@@ -60,7 +60,7 @@ class ReturnFunctionSignature<inType, indxType, true> {
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table,
+ const std::int32_t* compressed_indices_table,
const int* mask);
};

@@ -203,18 +203,19 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
asmjit::FuncDetail func;

if (ROWWISE_SPARSE) {
- func.init(asmjit::FuncSignatureT<
-     bool,
-     std::int64_t, // output_size
-     std::int64_t, // index_size
-     std::int64_t, // uncompressed_data_size
-     const inType*, // input uint8_t or float
-     const indxType*, // indices
-     const int*, // lengths
-     const float*, // weights
-     float*, // out
-     const indxType*, // compressed_indices_table and then mask
-     const int*>(asmjit::CallConv::kIdHost));
+ func.init(
+     asmjit::FuncSignatureT<
+         bool,
+         std::int64_t, // output_size
+         std::int64_t, // index_size
+         std::int64_t, // uncompressed_data_size
+         const inType*, // input uint8_t or float
+         const indxType*, // indices
+         const int*, // lengths
+         const float*, // weights
+         float*, // out
+         const std::int32_t*, // compressed_indices_table and then mask
+         const int*>(asmjit::CallConv::kIdHost));
} else {
func.init(asmjit::FuncSignatureT<
bool,
@@ -450,21 +451,12 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::
a->jge(error);

if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->mov(
-       scratchReg1_,
-       x86::qword_ptr(
-           compressed_indices_table,
-           scratchReg1_,
-           3)); // use of 3 is to multiply by 8
- } else {
-   a->mov(
-       scratchReg1_.r32(),
-       x86::dword_ptr(
-           compressed_indices_table,
-           scratchReg1_,
-           2)); // use of 2 is to multiply by 4
- }
+ a->mov(
+     scratchReg1_.r32(),
+     x86::dword_ptr(
+         compressed_indices_table,
+         scratchReg1_,
+         2)); // use of 2 is to multiply by 4
}

int fused_block_size = is8bit
@@ -511,21 +503,12 @@ typename ReturnFunctionSignature<inType, indxType, ROWWISE_SPARSE>::

a->bind(pref_dist_reset_end);
if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->mov(
-       scratchReg2_,
-       x86::qword_ptr(
-           compressed_indices_table,
-           scratchReg2_,
-           3)); // use of 3 is to multiply by 8
- } else {
-   a->mov(
-       scratchReg2_.r32(),
-       x86::dword_ptr(
-           compressed_indices_table,
-           scratchReg2_,
-           2)); // use of 2 is to multiply by 4
- }
+ a->mov(
+     scratchReg2_.r32(),
+     x86::dword_ptr(
+         compressed_indices_table,
+         scratchReg2_,
+         2)); // use of 2 is to multiply by 4
}
a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size));
}
@@ -538,11 +521,7 @@
}

if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1));
- } else {
-   a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
- }
+ a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
a->je(LoopDataIndexBegin);
}

@@ -922,7 +901,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table) {
+ const std::int32_t* compressed_indices_table) {
return original_func(
output_size,
index_size,
@@ -953,7 +932,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table) {
+ const std::int32_t* compressed_indices_table) {
return original_func(
output_size,
index_size,
@@ -979,7 +958,7 @@ GenerateEmbeddingSpMDMRowWiseSparse(
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
float* out,
- const indxType* compressed_indices_table) {
+ const std::int32_t* compressed_indices_table) {
return EmbeddingSpMDMRowWiseSparse_ref(
block_size,
output_size,
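
The deleted qword/dword branches above come down to scaled-index addressing. In plain C++ terms, the two variants the JIT used to choose between are simply the following (illustrative only); with the table fixed to int32_t, only the dword form survives, so areIndices64b no longer affects the mapping-table load:

    #include <cstdint>

    // Removed 64-bit path: 8-byte entries, index scaled by 8 (shift of 3).
    std::int64_t load_entry64(const std::int64_t* table, std::int64_t i) {
      return table[i]; // address = table + i * 8
    }

    // Remaining 32-bit path: 4-byte entries, index scaled by 4 (shift of 2).
    std::int32_t load_entry32(const std::int32_t* table, std::int64_t i) {
      return table[i]; // address = table + i * 4
    }
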
58 changes: 18 additions & 40 deletions src/EmbeddingSpMDMNBit.cc
@@ -67,7 +67,7 @@ class ReturnFunctionSignature<indxType, true> {
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table,
+ const int32_t* compressed_indices_table,
const int* mask);
};

@@ -217,7 +217,7 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
const int*, // lengths
const float*, // weights
float*, // out
- const indxType* /* compressed_indices_table */,
+ const int32_t* /* compressed_indices_table */,
const int* /* mask */>(asmjit::CallConv::kIdHost));
} else {
func.init(asmjit::FuncSignatureT<
@@ -493,21 +493,12 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
a->jge(error);

if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->mov(
-       scratchReg1_,
-       x86::qword_ptr(
-           compressed_indices_table,
-           scratchReg1_,
-           3)); // use of 3 is to multiply by 8
- } else {
-   a->mov(
-       scratchReg1_.r32(),
-       x86::dword_ptr(
-           compressed_indices_table,
-           scratchReg1_,
-           2)); // use of 2 is to multiply by 4
- }
+ a->mov(
+     scratchReg1_.r32(),
+     x86::dword_ptr(
+         compressed_indices_table,
+         scratchReg1_,
+         2)); // use of 2 is to multiply by 4
}

int num_elem_per_byte = 8 / bit_rate;
@@ -553,21 +544,12 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(

a->bind(pref_dist_reset_end);
if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->mov(
-       scratchReg2_,
-       x86::qword_ptr(
-           compressed_indices_table,
-           scratchReg2_,
-           3)); // use of 3 is to multiply by 8
- } else {
-   a->mov(
-       scratchReg2_.r32(),
-       x86::dword_ptr(
-           compressed_indices_table,
-           scratchReg2_,
-           2)); // use of 2 is to multiply by 4
- }
+ a->mov(
+     scratchReg2_.r32(),
+     x86::dword_ptr(
+         compressed_indices_table,
+         scratchReg2_,
+         2)); // use of 2 is to multiply by 4
}
// This has to be fused_block_size
a->imul(scratchReg2_, static_cast<asmjit::Imm>(fused_block_size));
@@ -581,11 +563,7 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
}

if (ROWWISE_SPARSE) {
- if (areIndices64b) {
-   a->cmp(scratchReg1_, static_cast<asmjit::Imm>(-1));
- } else {
-   a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
- }
+ a->cmp(scratchReg1_.r32(), static_cast<asmjit::Imm>(-1));
a->je(LoopDataIndexBegin);
}

@@ -977,7 +955,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table) {
+ const int32_t* compressed_indices_table) {
return original_func(
output_size,
index_size,
@@ -1009,7 +987,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table) {
+ const int32_t* compressed_indices_table) {
return original_func(
output_size,
index_size,
@@ -1034,7 +1012,7 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
const int* lengths,
const float* weights,
float* out,
- const indxType* compressed_indices_table) {
+ const int32_t* compressed_indices_table) {
return EmbeddingSpMDMNBitRowWiseSparse_ref(
bit_rate,
block_size,
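
As a side note on the N-bit path touched above, the `num_elem_per_byte = 8 / bit_rate` computation seen in the benchmark and JIT code determines how many quantized elements share a byte. A small illustrative helper (my own sketch, not FBGEMM code; it ignores any per-row scale/bias in the fused layout):

    #include <cstdint>

    // Bytes needed for the packed data portion of one N-bit row.
    std::int64_t packed_row_data_bytes(std::int64_t block_size, int bit_rate) {
      int num_elem_per_byte = 8 / bit_rate; // e.g. bit_rate == 4 -> 2 elements per byte
      return (block_size + num_elem_per_byte - 1) / num_elem_per_byte; // round up
    }
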