Skip to content

Commit

Permalink
add option to use int64_t offsets/lengths in embedding operators (pyt…
Browse files Browse the repository at this point in the history
…orch#350)

Summary:
Pull Request resolved: pytorch#350

PyTorch passes int64_t offsets.
Inside the JIT'ed code, we actually cast int64 offsets/lengths to int32 but in practice this shouldn't be a problem unless we have more than 2B indices.

Other changes:
Move inst_set_t template parameter from generate function to class template parameter to follow the changes in other code generators in fbgemm

Reviewed By: jianyuh

Differential Revision: D20959297

fbshipit-source-id: c567ed979e8a35eefb04d3a4ce57c8afa13b066d
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Apr 10, 2020
1 parent 20b6cf1 commit 3db8b84
Showing 13 changed files with 1,948 additions and 1,427 deletions.
108 changes: 66 additions & 42 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
@@ -12,22 +12,26 @@

namespace fbgemm {

template <typename inType, typename IndexType>
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
class EmbeddingSpMDMKernelSignature {
public:
using Type = std::function<bool(
std::int64_t output_size,
std::int64_t index_size,
std::int64_t data_size,
const inType* input,
const InType* input,
const IndexType* indices,
const int* offsets_or_lengths,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
float* out)>;
};

/**
* @tparam inType can be float or uint8_t
* @tparam InType can be float, float16, or uint8_t
* @tparam IndexType can be int32_t or int64_t
* @tparam OffsetType can be int32_t or int64_t
*
* @param use_offsets If true, the generated code assumes we will pass offsets
@@ -38,22 +42,30 @@ class EmbeddingSpMDMKernelSignature {
* If false, the generated code assumes we will pass lengths
* that conform to the Caffe2 SparseLengthsSum interface.
*/
template <typename inType, typename IndexType>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<inType, IndexType>::Type
GenerateEmbeddingSpMDM(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
FBGEMM_API
typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType>::Type
GenerateEmbeddingSpMDM(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);

/**
* @tparam IndexType can be int32_t or int64_t
* @tparam OffsetType can be int32_t or int64_t
* @param bit_rate can be 2 or 4
*/
template <typename IndexType>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<std::uint8_t, IndexType>::Type
template <typename IndexType, typename OffsetType = std::int32_t>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
std::uint8_t,
IndexType,
OffsetType>::Type
GenerateEmbeddingSpMDMNBit(
int bit_rate,
const std::int64_t block_size,
@@ -63,45 +75,56 @@ GenerateEmbeddingSpMDMNBit(
bool is_weight_positional = false,
bool use_offsets = true);

template <typename inType, typename IndexType>
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
class EmbeddingSpMDMRowWiseSparseKernelSignature {
public:
using Type = std::function<bool(
std::int64_t output_size,
std::int64_t index_size,
std::int64_t uncompressed_data_size,
// TODO: add compressed_data_size and check array bound
const inType* input,
const InType* input,
const IndexType* indices,
const int* offsets_or_lengths,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
float* out,
const std::int32_t* compressed_indices_table)>;
};

/**
* @tparam inType can be float or uint8_t
* @tparam InType can be float, float16, or uint8_t
* @tparam IndexType can be int32_t or int64_t
* @tparam OffsetType can be int32_t or int64_t
*/
template <typename inType, typename IndexType>
FBGEMM_API
typename EmbeddingSpMDMRowWiseSparseKernelSignature<inType, IndexType>::Type
GenerateEmbeddingSpMDMRowWiseSparse(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
InType,
IndexType,
OffsetType>::Type
GenerateEmbeddingSpMDMRowWiseSparse(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);

/**
* @tparam IndexType can be int32_t or int64_t
* @tparam OffsetType can be int32_t or int64_t
* @param bit_rate can be 2 or 4
*/
template <typename IndexType>
template <typename IndexType, typename OffsetType = std::int32_t>
FBGEMM_API typename EmbeddingSpMDMRowWiseSparseKernelSignature<
std::uint8_t,
IndexType>::Type
IndexType,
OffsetType>::Type
GenerateEmbeddingSpMDMNBitRowWiseSparse(
int bit_rate,
const std::int64_t block_size,
@@ -151,7 +174,7 @@ FBGEMM_API int SparseAdaGrad(
int prefetch = 16);

// RowWiseSparseAdaGrad fused with SLS gradient
template <typename IndexType>
template <typename IndexType, typename OffsetType = std::int32_t>
class RowWiseSparseAdaGradFusedSignature {
public:
using Type = std::function<bool(
@@ -162,28 +185,29 @@ class RowWiseSparseAdaGradFusedSignature {
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
const int* offsets_or_lengths,
const OffsetType* offsets_or_lengths,
float epsilon,
float lr)>;
};

template <typename IndexType>
FBGEMM_API typename RowWiseSparseAdaGradFusedSignature<IndexType>::Type
GenerateRowWiseSparseAdaGradFused(
int block_size, // number of parameters per row
int prefetch = 16,
bool use_offsets = true);
template <typename IndexType, typename OffsetType = std::int32_t>
FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<IndexType, OffsetType>::Type
GenerateRowWiseSparseAdaGradFused(
int block_size, // number of parameters per row
int prefetch = 16,
bool use_offsets = true);

namespace internal {
// Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
template <typename inType = float, typename IndexType = std::int64_t>
template <typename InType, typename IndexType, typename OffsetType>
FBGEMM_API bool EmbeddingSpMDMBlockSize1_(
const std::int64_t output_size,
const std::int64_t index_size,
const std::int64_t data_size, // the number of rows in input
const inType* input,
const InType* input,
const IndexType* indices,
const int* offsets_or_lengths,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
bool normalize_by_lengths,
float* out,
Loading

0 comments on commit 3db8b84

Please sign in to comment.