Skip to content

Commit

Permalink
enable EmbeddingSpMDMNBitRowWiseSparse avx2 (pytorch#280)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#280

As title

Reviewed By: jianyuh

Differential Revision: D19686831

fbshipit-source-id: df7abf8d25e899fb6db82b0f60b00f0a99e8eaae
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Feb 4, 2020
1 parent e384ddc commit 07fba56
Showing 1 changed file with 11 additions and 3 deletions.
14 changes: 11 additions & 3 deletions src/EmbeddingSpMDMNBit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ GenEmbeddingSpMDMNBitLookup<indxType, ROWWISE_SPARSE>::getOrCreate(
x86::Gp scratchReg2_ = a->gpz(reg_id); // 14 or 15
x86::Gp scratchReg3_;
if (instSet == inst_set_t::avx2) {
// Can't be combined with ROWWISE_SPARSE
++reg_id;
scratchReg3_ = a->gpz(reg_id); // 15
scratchReg3_ = a->zax();
}

asmjit::FuncDetail func;
Expand Down Expand Up @@ -931,6 +929,16 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
/*is_weight_positional*/ false,
normalize_by_lengths,
prefetch);
} else if (fbgemmHasAvx2Support()) {
static GenEmbeddingSpMDMNBitLookup<indxType, true /* rowwise_sparse */>
kernel_generator;
return kernel_generator.template getOrCreate<inst_set_t::avx2>(
bit_rate,
block_size,
has_weight,
/*is_weight_positional*/ false,
normalize_by_lengths,
prefetch);
} else {
#ifdef VLOG
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
Expand Down

0 comments on commit 07fba56

Please sign in to comment.