Skip to content

Commit

Permalink
add more instantiation for EmbeddingSpMDMAvx2 (pytorch#274)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#274

Solved unresolved references in test/EmbeddingSpMDMTest.cc. Add instantiation for input type uint8_t.

Reviewed By: jspark1105

Differential Revision: D19654528

fbshipit-source-id: 5dca40e1f6530c5720f290c27730e2e13bd9dce7
  • Loading branch information
Hongzhang Shan authored and jspark1105 committed Mar 21, 2020
1 parent 492a0b8 commit 1ea1a02
Showing 1 changed file with 24 additions and 0 deletions.
24 changes: 24 additions & 0 deletions src/EmbeddingSpMDMAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -173,5 +173,29 @@ template bool EmbeddingSpMDMBlockSize1_(
float* out,
bool is_weight_positional);

template bool EmbeddingSpMDMBlockSize1_(
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size, // the number of rows in input
const std::uint8_t* input,
const std::int64_t* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out,
bool is_weight_positional);

template bool EmbeddingSpMDMBlockSize1_(
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size, // the number of rows in input
const std::uint8_t* input,
const std::int32_t* indices,
const int* lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out,
bool is_weight_positional);

} // namespace internal
} // namespace fbgemm

0 comments on commit 1ea1a02

Please sign in to comment.