add grad stride option to fused rowwise sparse adagrad (pytorch#568)
Summary:
Pull Request resolved: pytorch#568

Adds a gradient stride option so the JIT'ed rowwise sparse AdaGrad kernel can be used for table-batched embedding.

Reviewed By: jianyuh

Differential Revision: D27241384

fbshipit-source-id: 9d59bb390acd9c9117c0928d50bdb20d60ce1a90
jspark1105 authored and facebook-github-bot committed Mar 23, 2021
1 parent 7f78584 commit 8998e6f
Showing 5 changed files with 86 additions and 35 deletions.
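At a glance, the new grad_stride parameter decouples the gradient matrix's row stride from the embedding dimension. A minimal sketch (illustrative values, not FBGEMM code) of the layout this enables, where one table's gradients live inside a wider batched buffer:

```cpp
#include <cstdio>
#include <vector>

int main() {
  const int batch_size = 2;
  const int block_size = 4;   // embedding dimension of one table
  const int grad_stride = 12; // row stride of the combined gradient buffer

  // Gradients for several tables packed side by side; this table's rows
  // start at column 0 but are grad_stride floats apart.
  std::vector<float> g(batch_size * grad_stride, 1.0f);

  for (int m = 0; m < batch_size; ++m) {
    const float* g_row = g.data() + m * grad_stride; // was m * block_size
    float sum = 0.0f;
    for (int j = 0; j < block_size; ++j) {
      sum += g_row[j]; // the kernel still reads only block_size floats per row
    }
    std::printf("row %d gradient sum = %f\n", m, sum);
  }
  return 0;
}
```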
8 changes: 6 additions & 2 deletions include/fbgemm/FbgemmEmbedding.h
@@ -74,7 +74,7 @@ FBGEMM_API
bool use_offsets = true);

/**
* @param output_stride If not -1, output_stride is not same as block_size
* @param output_stride If -1, output_stride is same as block_size
*/
template <
typename InType,
@@ -222,6 +222,9 @@ class RowWiseSparseAdaGradFusedSignature {
float lr)>;
};

/**
* @param grad_stride If -1, grad_stride is same as block size
*/
template <
typename IndexType,
typename OffsetType = std::int32_t,
@@ -234,7 +237,8 @@ GenerateRowWiseSparseAdaGradFused(
int block_size, // number of parameters per row
int prefetch = 16,
bool use_offsets = true,
bool use_stochastic_rounding = true);
bool use_stochastic_rounding = true,
int grad_stride = -1);

namespace internal {
// Specialization for block size 1 internally called by GenerateEmbeddingSpMDM
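A hedged usage sketch of the extended factory (parameter values are illustrative; the template arguments follow the instantiations later in this commit):

```cpp
#include <cstdint>
#include "fbgemm/FbgemmEmbedding.h"

void build_kernel() {
  const int block_size = 64;   // parameters per row
  const int grad_stride = 192; // e.g. combined width of a batched gradient buffer
  auto kernel = fbgemm::GenerateRowWiseSparseAdaGradFused<
      std::int64_t,  // IndexType
      std::int32_t,  // OffsetType
      float>(        // DataType
      block_size,
      /*prefetch=*/16,
      /*use_offsets=*/true,
      /*use_stochastic_rounding=*/true, // only meaningful for float16 weights
      grad_stride); // default -1 keeps grad_stride == block_size
  (void)kernel;
}
```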
12 changes: 9 additions & 3 deletions src/RefImplementations.cc
@@ -1593,7 +1593,12 @@ int rowwise_sparse_adagrad_fused_ref(
float lr,
bool use_offsets,
bool use_stochastic_rounding,
int emu_vector_size) {
int emu_vector_size,
int64_t grad_stride) {
if (grad_stride == -1) {
grad_stride = block_size;
}

constexpr bool isFloat16w = std::is_same<float16, DataType>::value;
// Local random buffer to emulate SIMD vector
// R: generated 32bit base random numbers
@@ -1614,7 +1619,7 @@
if (current + len > index_size) {
return false;
}
const float* g_ = g + m * block_size;
const float* g_ = g + m * grad_stride;
// Note the following code assumes fbgemm will generate AVX2 code for
// horizontal reduction, which is OK for now because fbgemm always uses AVX2
// for SparseAdagrad, whose performance is bounded by memory bandwidth
@@ -1889,7 +1894,8 @@ template FBGEMM_API int rowwise_sparse_adagrad_ref(
float lr, \
bool use_offsets, \
bool use_stochastic_rounding, \
int emu_vector_size);
int emu_vector_size, \
int64_t grad_stride);

#define INSTANTIATE_SPMDM_OFFSET_T(DATA_TYPE, INDEX_TYPE) \
INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, int32_t) \
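In scalar terms, the reference change amounts to the following (a condensed paraphrase, not the full reference implementation):

```cpp
#include <cstdint>

// Locate each output row's gradient with grad_stride as the row pitch;
// -1 preserves the old behavior of a densely packed gradient matrix.
void locate_gradients(
    const float* g,
    std::int64_t output_size,
    int block_size,
    std::int64_t grad_stride) {
  if (grad_stride == -1) {
    grad_stride = block_size;
  }
  for (std::int64_t m = 0; m < output_size; ++m) {
    const float* g_ = g + m * grad_stride; // previously g + m * block_size
    (void)g_; // ... per-row AdaGrad update over block_size elements ...
  }
}
```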
3 changes: 2 additions & 1 deletion src/RefImplementations.h
@@ -360,6 +360,7 @@ FBGEMM_API int rowwise_sparse_adagrad_fused_ref(
float lr,
bool use_offsets = true,
bool use_stochastic_rounding = true, // For DataType=float16
int emu_vector_size = 8);
int emu_vector_size = 8,
std::int64_t grad_stride = -1);

} // namespace fbgemm
64 changes: 45 additions & 19 deletions src/RowWiseSparseAdagradFused.cc
@@ -56,7 +56,8 @@ class GenRowWiseSparseAdagradFused {
int block_size,
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

private:
static asmjit::JitRuntime& runtime() {
@@ -70,7 +71,7 @@ class GenRowWiseSparseAdagradFused {
// avx2 mask array, embedding dimension (block size), prefetch distance,
// use_offsets and use_stochastic_rounding switches, and gradient row stride
static CodeCache<
tuple<const int*, int, int, bool, bool>,
tuple<const int*, int, int, bool, bool, int>,
typename ReturnFunctionSignature<indxType, offsetType, dataType>::
jit_sparse_adagrad_kernel>
codeCache_; ///< JIT Code Cache for reuse.
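The cache-key change matters for correctness: a kernel JIT'ed for one grad_stride must not be reused for another. A toy illustration of the idea (assumed types; not FBGEMM's actual CodeCache):

```cpp
#include <functional>
#include <map>
#include <tuple>

using KernelFn = std::function<void()>;
// mask ptr, block_size, prefetch, use_offsets, use_stochastic_rounding, grad_stride
using KernelKey = std::tuple<const int*, int, int, bool, bool, int>;

std::map<KernelKey, KernelFn> code_cache;

KernelFn get_or_create(const KernelKey& key,
                       const std::function<KernelFn()>& generate) {
  auto it = code_cache.find(key);
  if (it != code_cache.end()) {
    return it->second; // same block_size but a different grad_stride misses here
  }
  return code_cache.emplace(key, generate()).first->second;
}
```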
@@ -90,7 +91,7 @@ template <
typename dataType,
inst_set_t instSet>
CodeCache<
tuple<const int*, int, int, bool, bool>,
tuple<const int*, int, int, bool, bool, int>,
typename ReturnFunctionSignature<indxType, offsetType, dataType>::
jit_sparse_adagrad_kernel>
GenRowWiseSparseAdagradFused<indxType, offsetType, dataType, instSet>::
@@ -109,9 +110,15 @@ typename ReturnFunctionSignature<indxType, offsetType, dataType>::
int block_size,
int prefetch,
bool use_offsets,
bool use_stochastic_rounding) {
tuple<const int*, int, int, bool, bool> kernelSig = make_tuple(
mask_avx2, block_size, prefetch, use_offsets, use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride) {
tuple<const int*, int, int, bool, bool, int> kernelSig = make_tuple(
mask_avx2,
block_size,
prefetch,
use_offsets,
use_stochastic_rounding,
grad_stride);

return codeCache_.getOrCreate(
kernelSig,
@@ -699,7 +706,7 @@ typename ReturnFunctionSignature<indxType, offsetType, dataType>::
a->bind(LoopDataIndexEnd);

a->add(lengths, static_cast<asmjit::Imm>(sizeof(offsetType)));
a->add(g, static_cast<asmjit::Imm>(block_size * sizeof(float)));
a->add(g, static_cast<asmjit::Imm>(grad_stride * sizeof(float)));

a->jmp(LoopRangeIndexBegin);
a->bind(LoopRangeIndexEnd);
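The scalar equivalent of the changed instruction (a sketch of the loop structure, not the generated AVX code): the inner update still touches block_size floats per row, while the row-to-row advance now uses grad_stride.

```cpp
void outer_loop_shape(const float* g, int output_size, int block_size,
                      int grad_stride) {
  for (int m = 0; m < output_size; ++m) {
    for (int j = 0; j < block_size; ++j) {
      float gj = g[j]; // fused rowwise AdaGrad consumes g[0..block_size)
      (void)gj;
    }
    g += grad_stride; // mirrors a->add(g, grad_stride * sizeof(float)) above
  }
}
```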
@@ -800,10 +807,14 @@ GenerateRowWiseSparseAdaGradFused(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding) {
bool use_stochastic_rounding,
int grad_stride) {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (grad_stride == -1) {
grad_stride = block_size;
}

// Use avx512 only for fp16 + stochastic rounding
if (fbgemmHasAvx512Support() && std::is_same<DataType, float16>::value &&
@@ -815,7 +826,12 @@
inst_set_t::avx512>
kernel_generator;
const auto original_func = kernel_generator.getOrCreate(
nullptr, block_size, prefetch, use_offsets, use_stochastic_rounding);
nullptr,
block_size,
prefetch,
use_offsets,
use_stochastic_rounding,
grad_stride);
const auto lambda_func = [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
@@ -858,7 +874,8 @@ GenerateRowWiseSparseAdaGradFused(
block_size,
prefetch,
use_offsets,
use_stochastic_rounding);
use_stochastic_rounding,
grad_stride);
const auto lambda_func = [=](int64_t output_size,
int64_t index_size,
int64_t data_size,
@@ -913,7 +930,8 @@ GenerateRowWiseSparseAdaGradFused(
epsilon,
lr,
use_offsets,
use_stochastic_rounding);
use_stochastic_rounding,
grad_stride);
};
}
}
@@ -924,62 +942,70 @@ template FBGEMM_API
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float>::Type
GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float>::Type
GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float>::Type
GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int64_t, int32_t, float16>::Type
GenerateRowWiseSparseAdaGradFused<int64_t, int32_t, float16>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int64_t, int64_t, float16>::Type
GenerateRowWiseSparseAdaGradFused<int64_t, int64_t, float16>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int32_t, int32_t, float16>::Type
GenerateRowWiseSparseAdaGradFused<int32_t, int32_t, float16>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

template FBGEMM_API
typename RowWiseSparseAdaGradFusedSignature<int32_t, int64_t, float16>::Type
GenerateRowWiseSparseAdaGradFused<int32_t, int64_t, float16>(
int block_size, // number of parameters per row
int prefetch,
bool use_offsets,
bool use_stochastic_rounding);
bool use_stochastic_rounding,
int grad_stride);

} // namespace fbgemm
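The dispatch in this file, paraphrased (simplified; the real conditions live in GenerateRowWiseSparseAdaGradFused above): pick AVX-512 only for fp16 weights with stochastic rounding, otherwise AVX2, otherwise the reference implementation, and thread grad_stride through every branch.

```cpp
enum class Path { Avx512, Avx2, Reference };

Path pick_path(bool has_avx512, bool has_avx2, bool weights_are_fp16,
               bool use_stochastic_rounding) {
  if (has_avx512 && weights_are_fp16 && use_stochastic_rounding) {
    return Path::Avx512; // AVX-512 is used only for fp16 + stochastic rounding
  }
  if (has_avx2) {
    return Path::Avx2;
  }
  return Path::Reference; // rowwise_sparse_adagrad_fused_ref(..., grad_stride)
}
```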
34 changes: 24 additions & 10 deletions test/RowWiseSparseAdagradFusedTest.cc
@@ -51,10 +51,15 @@ vector<int> prefetch_distances{0, 16, 1000000};

namespace {

class RowWiseSparseAdagradFusedTest
: public testing::TestWithParam<
tuple<bool, bool, bool, bool, int, bool, EmbeddingSpMDMCornerCase>> {
};
class RowWiseSparseAdagradFusedTest : public testing::TestWithParam<tuple<
bool,
bool,
bool,
bool,
int,
bool,
EmbeddingSpMDMCornerCase,
bool>> {};
}; // namespace

INSTANTIATE_TEST_CASE_P(
@@ -71,12 +76,13 @@ INSTANTIATE_TEST_CASE_P(
NONE,
EMPTY_INDICES,
OUT_OF_BOUND_INDICES,
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM)));
UNMATCHED_NUM_INDICES_AND_LENGTHS_SUM),
::testing::Bool())); // grad_stride != block_size

TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
vector<vector<int>> inputs(GetInputs_());
bool isWeightFp16, useStochasticRounding, isIndex64b, isOffset64b,
use_offsets;
use_offsets, use_grad_stride;
int prefetch;
EmbeddingSpMDMCornerCase corner_case;
tie(isWeightFp16,
@@ -85,7 +91,8 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
isOffset64b,
prefetch,
use_offsets,
corner_case) = GetParam();
corner_case,
use_grad_stride) = GetParam();

if (!isWeightFp16 && useStochasticRounding) {
// stochastic rounding makes sense only for fp16 weight
@@ -102,10 +109,12 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
int num_rows = input[1];
int embedding_dim = input[2];
int average_len = input[3];
int grad_stride = use_grad_stride ? embedding_dim * 2 + 3 : -1;

// Create embedding table
vector<float> w(num_rows * embedding_dim), w_ref(num_rows * embedding_dim),
h(num_rows), h_ref(num_rows), g(batch_size * embedding_dim);
h(num_rows), h_ref(num_rows),
g(batch_size * (use_grad_stride ? grad_stride : embedding_dim));
vector<float16> w_fp16(w.size()), w_fp16_ref(w.size());
default_random_engine generator;
uniform_real_distribution<float> values_gen(0, 2);
@@ -160,14 +169,19 @@ TEST_P(RowWiseSparseAdagradFusedTest, rowwiseTest) {
lr, \
use_offsets, \
useStochasticRounding, \
vlen); \
vlen, \
grad_stride); \
} while (0)

#define JIT(WeightType, IndexType, OffsetType, Weights, Indices, Offsets) \
do { \
auto kernel = \
GenerateRowWiseSparseAdaGradFused<IndexType, OffsetType, WeightType>( \
embedding_dim, prefetch, use_offsets, useStochasticRounding); \
embedding_dim, \
prefetch, \
use_offsets, \
useStochasticRounding, \
grad_stride); \
success = kernel( \
batch_size, \
lengths_sum, \
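The test's buffer sizing, restated (a sketch mirroring the diff; embedding_dim * 2 + 3 is just an arbitrary stride larger than the block size, so padding columns stay untouched):

```cpp
#include <vector>

std::vector<float> make_grad_buffer(int batch_size, int embedding_dim,
                                    bool use_grad_stride) {
  // -1 asks the kernel to treat the gradient matrix as densely packed.
  const int grad_stride = use_grad_stride ? embedding_dim * 2 + 3 : -1;
  const int row_width = use_grad_stride ? grad_stride : embedding_dim;
  return std::vector<float>(batch_size * row_width);
}
```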
