add new sparse adagrad interface that takes weight_decay at call time, not at codegen time (pytorch#432)

Summary:
Pull Request resolved: pytorch#432

To support usages like D24040476. Follow-up diffs will move callers to the new interface (the change is split into multiple diffs to account for the small time lag between the fbgemm and pytorch/caffe2 GitHub syncs):
* D24195799 (this diff): add a new interface to fbgemm under a temporary name with a New suffix
* D24196753: Caffe2 calls the new interface under the temporary name
* D24197553: change the old names to use the new interface
* D24197694: Caffe2 calls the old names with the new interface
* D24197727: remove the temporary name

It's like a swap operation that needs a temp variable (sort of).
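For reference, a minimal caller-side sketch (illustrative only, not part of this diff; variable names are placeholders) contrasting the old codegen-time weight_decay with the new call-time weight_decay:

```cpp
// Caller-side sketch (illustrative only, not part of this diff).
#include "fbgemm/FbgemmEmbedding.h"

#include <cstdint>
#include <vector>

void adagrad_step_example(
    int block_size,
    int num_rows,
    std::vector<float>& w,                      // input/output parameters
    const std::vector<float>& g,                // input gradients
    std::vector<float>& h,                      // input/output momentums
    const std::vector<std::int64_t>& indices,   // indices of each row
    float epsilon,
    float lr,
    float weight_decay) {
  // Old interface: weight_decay is fixed when the kernel is generated.
  auto fn_old = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
      block_size, /*rowwise=*/false, /*prefetch=*/16, weight_decay);
  fn_old(num_rows, w.size(), w.data(), g.data(), h.data(),
         indices.data(), epsilon, lr);

  // New interface: codegen only records whether weight decay is used;
  // the actual value is supplied on every call.
  auto fn_new = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
      block_size, /*rowwise=*/false, /*prefetch=*/16,
      /*use_weight_decay=*/weight_decay != 0.0f);
  fn_new(num_rows, w.size(), w.data(), g.data(), h.data(),
         indices.data(), epsilon, lr, weight_decay);
}
```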

Reviewed By: dskhudia

Differential Revision: D24195799

fbshipit-source-id: f89a23291ffa2c241618a01cb7ab445484bd4e4f
jspark1105 authored and facebook-github-bot committed Oct 9, 2020
1 parent fe91640 commit 974d2b4
Showing 5 changed files with 146 additions and 18 deletions.
bench/RowwiseAdagradBenchmark.cc (5 changes: 3 additions & 2 deletions)
@@ -73,7 +73,7 @@ void run_benchmark(
constexpr int NUM_ITER = 10;
double data_moved = num_rows * (3 * sizeof(float) * block_size + 2 * 64);

auto fn = GenerateSparseAdaGrad<int64_t>(block_size, /*rowwise=*/true);
auto fn = GenerateSparseAdaGradNew<int64_t>(block_size, /*rowwise=*/true);

double t = measureWithWarmup(
[&]() {
@@ -84,7 +84,8 @@
h.data(), // input momentums
indices.data(), // indices of each row
epsilon,
lr);
lr,
0.0f); // weight_decay
},
NUM_WARMUP,
NUM_ITER,
bench/SparseAdagradBenchmark.cc (10 changes: 6 additions & 4 deletions)
@@ -80,7 +80,7 @@ void run_benchmark(

double t = 0.0;
if (isIndex64b) {
auto fn_indices_64 = GenerateSparseAdaGrad<int64_t>(block_size);
auto fn_indices_64 = GenerateSparseAdaGradNew<int64_t>(block_size);

t = measureWithWarmup(
[&]() {
@@ -92,7 +92,8 @@
h.data(), // input momentums
indices.data(), // indices of each row
epsilon,
lr);
lr,
0.0f); // weight_decay
},
NUM_WARMUP,
NUM_ITER,
@@ -111,7 +112,7 @@
lr);
}
} else {
auto fn_indices_32 = GenerateSparseAdaGrad<int32_t>(block_size);
auto fn_indices_32 = GenerateSparseAdaGradNew<int32_t>(block_size);

t = measureWithWarmup(
[&]() {
@@ -123,7 +124,8 @@
h.data(), // input momentums
indices_32.data(), // indices of each row
epsilon,
lr);
lr,
0.0f); // weight_decay
},
NUM_WARMUP,
NUM_ITER,
include/fbgemm/FbgemmEmbedding.h (18 changes: 18 additions & 0 deletions)
@@ -167,6 +167,16 @@ class SparseAdaGradSignature {
const IndexType* indices, // indices of each row
float epsilon,
float lr)>;
using NewType = std::function<int(
int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay)>;
};

template <typename IndexType>
@@ -177,6 +187,14 @@ GenerateSparseAdaGrad(
int prefetch = 16,
float weight_decay = 0.0f);

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignature<IndexType>::NewType
GenerateSparseAdaGradNew(
int block_size, // number of parameters per row
bool rowwise = false,
int prefetch = 16,
bool use_weight_decay = false);

// RowWiseSparseAdaGrad fused with SLS gradient
// Weights can be either float or float16
template <
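Because weight_decay is now a runtime argument of NewType, a single generated kernel can be reused while the effective weight decay varies between iterations. A sketch of that pattern (the decay schedule and names below are hypothetical, not from this diff):

```cpp
// Illustrative sketch: generate one kernel, vary weight_decay per iteration.
#include "fbgemm/FbgemmEmbedding.h"

#include <cstdint>

void train_loop_example(
    int block_size,
    int num_iters,
    int num_rows,
    std::uint64_t param_size,
    float* w,
    const float* g,            // gradients (recomputed each iteration in practice)
    float* h,
    const std::int64_t* indices,
    float epsilon,
    float lr,
    float base_weight_decay) {
  // Generate once; codegen only needs to know that weight decay will be used.
  auto fn = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
      block_size, /*rowwise=*/false, /*prefetch=*/16,
      /*use_weight_decay=*/true);

  for (int it = 0; it < num_iters; ++it) {
    // With the old interface, changing weight_decay would require
    // regenerating the kernel; here it is just another argument.
    float weight_decay = base_weight_decay / (1.0f + it);
    fn(num_rows, param_size, w, g, h, indices, epsilon, lr, weight_decay);
  }
}
```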
src/SparseAdagrad.cc (103 changes: 103 additions & 0 deletions)
@@ -909,4 +909,107 @@ GenerateSparseAdaGrad<std::int32_t>(
int prefetch,
float weight_decay);

template <typename IndexType>
typename SparseAdaGradSignature<IndexType>::NewType GenerateSparseAdaGradNew(
int block_size,
bool rowwise,
int prefetch,
bool use_weight_decay) {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}

if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
if (block_size == 1) {
return [=](int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay) {
return SparseAdaGradBlockSize1_(
num_rows,
param_size,
w,
g,
h,
indices,
epsilon,
lr,
rowwise,
weight_decay);
};
}
static GenSparseAdagrad<IndexType, inst_set_t::avx2> kernel_generator;
constexpr int VLEN = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
[(VLEN - (block_size % VLEN)) % VLEN];
const auto original_func = kernel_generator.getOrCreate(
block_size, prefetch, rowwise, use_weight_decay);
return [=](int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay) {
return original_func(
num_rows, // number of rows reading
param_size, // total number of parameters
w, // input/output parameters
g, // input gradients
h, // input/output momentums
indices, // indices of each row
epsilon,
lr,
mask_avx2,
weight_decay);
};
} else {
#ifdef VLOG
VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
#endif
return [=](int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay) {
return sparse_adagrad_ref(
num_rows, // number of rows reading
block_size, // number of parameters per row
param_size, // total number of parameters
w, // input/output parameters
g, // input gradients
h, // input/output momentums
indices,
epsilon,
lr,
weight_decay);
};
}
}

template FBGEMM_API typename SparseAdaGradSignature<std::int64_t>::NewType
GenerateSparseAdaGradNew<std::int64_t>(
    int block_size, // number of parameters per row
bool rowwise,
int prefetch,
bool use_weight_decay);

template FBGEMM_API typename SparseAdaGradSignature<std::int32_t>::NewType
GenerateSparseAdaGradNew<std::int32_t>(
    int block_size, // number of parameters per row
bool rowwise,
int prefetch,
bool use_weight_decay);

} // namespace fbgemm
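For intuition about the non-JIT slow path, here is a scalar sketch of the per-element update one would expect sparse_adagrad_ref to perform in the non-rowwise case, under the assumption that weight decay is folded into the gradient as weight_decay * w; this is a reconstruction of the expected semantics, not a copy of the reference implementation:

```cpp
// Scalar sketch (assumption: weight decay is applied by adding
// weight_decay * w to the gradient before the AdaGrad update).
#include <cmath>
#include <cstdint>

template <typename IndexType>
int sparse_adagrad_scalar_sketch(
    int num_rows,
    int block_size,
    std::uint64_t param_size,
    float* w,              // input/output parameters
    const float* g,        // input gradients, dense over the rows read
    float* h,              // input/output momentums
    const IndexType* indices,
    float epsilon,
    float lr,
    float weight_decay) {
  for (int i = 0; i < num_rows; ++i) {
    std::uint64_t offset =
        static_cast<std::uint64_t>(indices[i]) * block_size;
    if (offset + block_size > param_size) {
      return i; // out-of-bounds index: report how many rows were processed
    }
    for (int j = 0; j < block_size; ++j) {
      float gj = g[i * block_size + j] + weight_decay * w[offset + j];
      h[offset + j] += gj * gj;
      w[offset + j] += lr * gj / (std::sqrt(h[offset + j]) + epsilon);
    }
  }
  return num_rows;
}
```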
test/SparseAdagradTest.cc (28 changes: 16 additions & 12 deletions)
@@ -123,8 +123,8 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) {
lr,
weight_decay);

auto fn_fbgemm = GenerateSparseAdaGrad<std::int64_t>(
block_size, false, prefetch, weight_decay);
auto fn_fbgemm = GenerateSparseAdaGradNew<std::int64_t>(
block_size, false, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
num_rows, // number of rows reading
@@ -134,7 +134,8 @@
h.data(), // input momentums
indices.data(), // indices of each row
epsilon,
lr);
lr,
weight_decay);
} else { // 32 bit indices
ret_ref = sparse_adagrad_ref(
num_rows, // number of rows reading
@@ -148,8 +149,8 @@
lr,
weight_decay);

auto fn_fbgemm = GenerateSparseAdaGrad<std::int32_t>(
block_size, false, prefetch, weight_decay);
auto fn_fbgemm = GenerateSparseAdaGradNew<std::int32_t>(
block_size, false, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
num_rows, // number of rows reading
@@ -159,7 +160,8 @@
h.data(), // input momentums
indices_32.data(), // indices of each row
epsilon,
lr);
lr,
weight_decay);
}

EXPECT_EQ(ret_fbgemm, ret_ref)
@@ -237,8 +239,8 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
lr,
weight_decay);

auto fn_fbgemm = GenerateSparseAdaGrad<std::int64_t>(
block_size, true, prefetch, weight_decay);
auto fn_fbgemm = GenerateSparseAdaGradNew<std::int64_t>(
block_size, true, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
num_rows, // number of rows reading
@@ -248,7 +250,8 @@
h.data(), // input momentums
indices.data(), // indices of each row
epsilon,
lr);
lr,
weight_decay);
} else { // 32 bit indices
ret_ref = rowwise_sparse_adagrad_ref(
num_rows, // number of rows reading
@@ -262,8 +265,8 @@
lr,
weight_decay);

auto fn_fbgemm = GenerateSparseAdaGrad<std::int32_t>(
block_size, true, prefetch, weight_decay);
auto fn_fbgemm = GenerateSparseAdaGradNew<std::int32_t>(
block_size, true, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
num_rows, // number of rows reading
@@ -273,7 +276,8 @@
h.data(), // input momentums
indices_32.data(), // indices of each row
epsilon,
lr);
lr,
weight_decay);
}

EXPECT_EQ(ret_fbgemm, ret_ref)
