add offset-based interface for PyTorch (pytorch#334)
Summary:
Pull Request resolved: pytorch#334

Add a variation of the embedding operators that uses offsets rather than lengths. This is for PyTorch EmbeddingBag. We assume the length of the offsets array is output_size + 1, with offsets[output_size] == index_size, conforming to the standard compressed sparse row (CSR) convention.
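
For concreteness, a minimal sketch of the relationship between the two conventions; the helper below is hypothetical and not part of this diff:

#include <cstddef>
#include <vector>

// Convert Caffe2-style lengths to PyTorch/CSR-style offsets.
// Example: lengths = {2, 0, 3} becomes offsets = {0, 2, 2, 5},
// so offsets.size() == output_size + 1 and offsets.back() == index_size.
std::vector<int> lengths_to_offsets(const std::vector<int>& lengths) {
  std::vector<int> offsets(lengths.size() + 1);
  offsets[0] = 0;
  for (std::size_t i = 0; i < lengths.size(); ++i) {
    offsets[i + 1] = offsets[i] + lengths[i]; // exclusive prefix sum
  }
  return offsets;
}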

This diff maintains backward compatibility (use_offsets = false).
We'll have a separate follow-up Caffe2 diff that explicitly passes use_offsets = false, then another diff will change the default of use_offsets to true (making the PyTorch convention the default). The last diff will break backward compatibility, but that should be fine because there are only a few call sites in Caffe2, and we can modify them accordingly.

Reviewed By: jianyuh

Differential Revision: D20746569

fbshipit-source-id: 7530eca8cbf1efb40032ec3e695fb4a64ae13884
jspark1105 authored and facebook-github-bot committed Mar 31, 2020
1 parent a6e81fb commit df34258
Showing 13 changed files with 533 additions and 300 deletions.
include/fbgemm/FbgemmEmbedding.h: 34 changes (24 additions & 10 deletions)
@@ -21,14 +21,22 @@ class EmbeddingSpMDMKernelSignature {
       std::int64_t data_size,
       const inType* input,
       const IndexType* indices,
-      const int* lengths,
+      const int* offsets_or_lengths,
       const float* weights, // optional, can be null for non-weighted sum
       float* out)>;
 };
 
 /**
  * @tparam inType can be float or uint8_t
  * @tparam IndexType can be int32_t or int64_t
+ *
+ * @param use_offsets If true, the generated code assumes we will pass offsets
+ *                    instead of lengths, conforming to the PyTorch
+ *                    EmbeddingBag interface. In this case, the length of the
+ *                    offsets array should be output_size + 1 and
+ *                    offsets[output_size] should be index_size.
+ *                    If false, the generated code assumes we will pass
+ *                    lengths, conforming to the Caffe2 SparseLengthsSum interface.
  */
 template <typename inType, typename IndexType>
 FBGEMM_API typename EmbeddingSpMDMKernelSignature<inType, IndexType>::Type
@@ -37,7 +45,8 @@ GenerateEmbeddingSpMDM(
     bool has_weight,
     bool normalize_by_lengths,
     int prefetch = 16,
-    bool is_weight_positional = false);
+    bool is_weight_positional = false,
+    bool use_offsets = false);
 
 /**
  * @tparam IndexType can be int32_t or int64_t
@@ -51,7 +60,8 @@ GenerateEmbeddingSpMDMNBit(
     bool has_weight,
     bool normalize_by_lengths,
     int prefetch = 16,
-    bool is_weight_positional = false);
+    bool is_weight_positional = false,
+    bool use_offsets = false);
 
 template <typename inType, typename IndexType>
 class EmbeddingSpMDMRowWiseSparseKernelSignature {
@@ -63,7 +73,7 @@ class EmbeddingSpMDMRowWiseSparseKernelSignature {
       // TODO: add compressed_data_size and check array bound
       const inType* input,
       const IndexType* indices,
-      const int* lengths,
+      const int* offsets_or_lengths,
       const float* weights, // optional, can be null for non-weighted sum
       float* out,
       const std::int32_t* compressed_indices_table)>;
@@ -81,7 +91,8 @@ FBGEMM_API
     bool has_weight,
     bool normalize_by_lengths,
     int prefetch = 16,
-    bool is_weight_positional = false);
+    bool is_weight_positional = false,
+    bool use_offsets = false);
 
 /**
  * @tparam IndexType can be int32_t or int64_t
@@ -97,7 +108,8 @@ GenerateEmbeddingSpMDMNBitRowWiseSparse(
     bool has_weight,
     bool normalize_by_lengths,
     int prefetch = 16,
-    bool is_weight_positional = false);
+    bool is_weight_positional = false,
+    bool use_offsets = false);
 
 /**
  * @return The number of rows processed. If smaller than num_rows, an error
@@ -150,7 +162,7 @@ class RowWiseSparseAdaGradFusedSignature {
       const float* g, // input gradients
       float* h, // input/output momentums
       const IndexType* indices, // indices of each row
-      const int* lengths,
+      const int* offsets_or_lengths,
       float epsilon,
       float lr)>;
 };
@@ -159,7 +171,8 @@ template <typename IndexType>
 FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<IndexType>::Type
 GenerateRowWiseSparseAdaGradFused(
     int block_size, // number of parameters per row
-    int prefetch = 16);
+    int prefetch = 16,
+    bool use_offsets = false);
 
 namespace internal {
 // Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
@@ -170,11 +183,12 @@ FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
     const std::int64_t data_size, // the number of rows in input
     const inType* input,
     const IndexType* indices,
-    const int* lengths,
+    const int* offsets_or_lengths,
     const float* weights, // optional, can be null for non-weighted sum
     bool normalize_by_lengths,
     float* out,
-    bool is_weight_positional = false);
+    bool is_weight_positional = false,
+    bool use_offsets = false);
 
 } // namespace internal
 
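To illustrate the use_offsets parameter documented in the diff above, a hedged usage sketch of generating and invoking an offsets-based kernel. It assumes the leading block_size parameter and the output_size/index_size kernel arguments that sit in the collapsed regions of the diff; the table sizes, indices, and offsets below are illustrative only:

#include <cstdint>
#include <vector>
#include "fbgemm/FbgemmEmbedding.h"

int main() {
  // Generate a kernel that expects PyTorch-style offsets (use_offsets = true).
  auto kernel = fbgemm::GenerateEmbeddingSpMDM<float, std::int64_t>(
      /*block_size=*/64,
      /*has_weight=*/false,
      /*normalize_by_lengths=*/false,
      /*prefetch=*/16,
      /*is_weight_positional=*/false,
      /*use_offsets=*/true);

  // Illustrative data: 3 bags gathering 5 rows of a 100 x 64 table.
  std::vector<float> table(100 * 64, 0.0f);
  std::vector<std::int64_t> indices = {3, 1, 7, 7, 42};
  std::vector<int> offsets = {0, 2, 2, 5}; // output_size + 1 entries; offsets[3] == index_size
  std::vector<float> out(3 * 64, 0.0f);

  kernel(/*output_size=*/3, /*index_size=*/5, /*data_size=*/100,
         table.data(), indices.data(), offsets.data(),
         /*weights=*/nullptr, out.data());
  return 0;
}

With use_offsets = true, the offsets array has output_size + 1 entries and offsets[output_size] == index_size; with the default use_offsets = false, the same argument slot would instead receive a Caffe2-style lengths array of output_size entries.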