Skip to content

Commit

Permalink
add input stride option to embedding spmdm
Browse files Browse the repository at this point in the history
Summary: To prepare D27562367 suppose input stride != block size

Reviewed By: jianyuh

Differential Revision: D27559072

fbshipit-source-id: 2d7ac6c394ab76188653eeaf97fd7413d5c79a01
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Apr 6, 2021
1 parent c565348 commit 9b623bb
Show file tree
Hide file tree
Showing 9 changed files with 268 additions and 186 deletions.
2 changes: 1 addition & 1 deletion fbgemm_gpu/codegen/embedding_forward_split_cpu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ void split_embedding_forward_cpu_kernel(
if (use_fbgemm) {
using fbgemm_weight_t =
typename ::internal::half2float16<weights_t>::type;
auto kernel = fbgemm::GenerateEmbeddingSpMDMWithOutputStride<
auto kernel = fbgemm::GenerateEmbeddingSpMDMWithStrides<
fbgemm_weight_t,
/*IndexType=*/int64_t,
/*OffsetType=*/int64_t>(
Expand Down
6 changes: 4 additions & 2 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,21 +75,23 @@ FBGEMM_API

/**
* @param output_stride If -1, output_stride is same as block_size
* @param input_stride If -1, input_stride is same as block_size
*/
template <
typename InType,
typename IndexType,
typename OffsetType = std::int32_t>
FBGEMM_API
typename EmbeddingSpMDMKernelSignature<InType, IndexType, OffsetType>::Type
GenerateEmbeddingSpMDMWithOutputStride(
GenerateEmbeddingSpMDMWithStrides(
const std::int64_t block_size,
bool has_weight,
bool normalize_by_lengths,
int prefetch = 16,
bool is_weight_positional = false,
bool use_offsets = true,
std::int64_t output_stride = -1);
std::int64_t output_stride = -1,
std::int64_t input_stride = -1);

/**
* @tparam IndexType can be int32_t or int64_t
Expand Down
Loading

0 comments on commit 9b623bb

Please sign in to comment.