Skip to content

Commit

Permalink
scale_bias_last option to support TBE layout (pytorch#848)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#848

This diff adds the following for use in quantized table batched embedding (TBE):
* scale_bias_last: true by default, which is the old fbgemm CPU embedding JIT'ed kernel behavior. If false, scale and bias appear at the beginning of each row and are in fp16, matching the TBE layout. If false, -1 indices (the output of pruned embedding id mapping) are also accepted.
* OutType can be fp16
* output_stride and input_stride support for int4/int2 embedding
* Fix a bug related to masking for fp16

Reviewed By: jianyuh

Differential Revision: D33430251

fbshipit-source-id: 59569f2b1ebf8cde40756fa3d7d013a61da6736d
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Jan 8, 2022
1 parent fa20cb3 commit 747fc4a
Show file tree
Hide file tree
Showing 8 changed files with 1,228 additions and 934 deletions.
92 changes: 66 additions & 26 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ namespace fbgemm {
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
typename OffsetType = std::int32_t,
typename OutType = float>
class EmbeddingSpMDMKernelSignature {
public:
/**
Expand Down Expand Up @@ -43,7 +44,7 @@ class EmbeddingSpMDMKernelSignature {
const IndexType* indices,
const OffsetType* offsets_or_lengths,
const float* weights, // optional, can be null for non-weighted sum
float* out)>;
OutType* out)>;
};

/**
Expand All @@ -62,16 +63,20 @@ class EmbeddingSpMDMKernelSignature {
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
FBGEMM_API
typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType>::Type
GenerateEmbeddingSpMDM(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);
typename OffsetType = std::int32_t,
typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
InType,
IndexType,
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDM(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true);

/**
* @param output_stride If -1, output_stride is same as block_size
Expand All @@ -80,29 +85,38 @@ FBGEMM_API
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
FBGEMM_API
typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType>::Type
GenerateEmbeddingSpMDMWithStrides(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true,
std::int64_t output_stride = -1,
std::int64_t input_stride = -1);
typename OffsetType = std::int32_t,
typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
InType,
IndexType,
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDMWithStrides(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true,
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
bool scale_bias_last = true);

/**
* @tparam IndexType can be int32_t or int64_t
* @tparam OffsetType can be int32_t or int64_t
* @param bit_rate can be 2 or 4
*/
template <typename IndexType, typename OffsetType = std::int32_t>
template <
typename IndexType,
typename OffsetType = std::int32_t,
typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
std::uint8_t,
IndexType,
OffsetType>::Type
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDMNBit(
int bit_rate,
const std::int64_t block_size,
Expand All @@ -112,6 +126,32 @@ GenerateEmbeddingSpMDMNBit(
bool is_weight_positional = false,
bool use_offsets = true);

/**
* Declaration of the strided counterpart of GenerateEmbeddingSpMDMNBit. The
* return type is
* EmbeddingSpMDMKernelSignature<std::uint8_t, IndexType, OffsetType,
* OutType>::Type, i.e. the input rows are passed as std::uint8_t and the
* output buffer is OutType*.
*
* @tparam OutType element type of the output; defaults to float.
*         NOTE(review): the commit description says fp16 output is also
*         supported -- confirm which types are explicitly instantiated.
* @param bit_rate presumably 2 or 4, as documented for
*        GenerateEmbeddingSpMDMNBit -- confirm.
* @param output_stride If -1, output_stride is the same as block_size
* @param input_stride in bytes. If -1, input_stride is the same as
*        block_size / num_elem_per_byte + 2 * sizeof(float16)
* @param scale_bias_last defaults to true. NOTE(review): per the commit
*        description, true keeps the previous row layout, while false
*        places the fp16 scale and bias at the beginning of each row (TBE
*        layout) and also allows -1 indices -- confirm against the
*        implementation.
*/
template <
typename IndexType,
typename OffsetType = std::int32_t,
typename OutType = float>
FBGEMM_API typename EmbeddingSpMDMKernelSignature<
std::uint8_t,
IndexType,
OffsetType,
OutType>::Type
GenerateEmbeddingSpMDMNBitWithStrides(
int bit_rate,
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true,
std::int64_t output_stride = -1,
std::int64_t input_stride = -1,
bool scale_bias_last = true);

template <
typename InType,
typename IndexType,
Expand Down
Loading

0 comments on commit 747fc4a

Please sign in to comment.