match old sparse adagrad interface with the new one (pytorch#434)
Summary:
Pull Request resolved: pytorch#434

As title

Reviewed By: dskhudia

Differential Revision: D24197553

fbshipit-source-id: 1200a087e4df9ef5385122e296da6a5999eea606
jspark1105 authored and facebook-github-bot committed Oct 9, 2020
1 parent 974d2b4 commit 75ea7ce
Showing 5 changed files with 21 additions and 105 deletions.
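At a glance (summarized from the diff below, not from the original commit message): the legacy SparseAdaGradSignature<IndexType>::Type gains the trailing float weight_decay argument that NewType already had, NewType becomes an alias of Type, and GenerateSparseAdaGrad takes bool use_weight_decay in place of float weight_decay, so the old entry point matches GenerateSparseAdaGradNew. For a caller the change looks roughly like this (variable names are illustrative, not taken from the diff):

// Before this commit: the weight-decay value was fixed at kernel-generation time.
auto fn_old = GenerateSparseAdaGrad<int64_t>(block_size, rowwise, prefetch, /*weight_decay=*/0.0f);
fn_old(num_rows, param_size, w, g, h, indices, epsilon, lr);

// After this commit: generation only toggles weight decay; the value is passed per call.
auto fn = GenerateSparseAdaGrad<int64_t>(block_size, rowwise, prefetch, /*use_weight_decay=*/false);
fn(num_rows, param_size, w, g, h, indices, epsilon, lr, /*weight_decay=*/0.0f);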
2 changes: 1 addition & 1 deletion bench/RowwiseAdagradBenchmark.cc
@@ -73,7 +73,7 @@ void run_benchmark(
constexpr int NUM_ITER = 10;
double data_moved = num_rows * (3 * sizeof(float) * block_size + 2 * 64);

- auto fn = GenerateSparseAdaGradNew<int64_t>(block_size, /*rowwise=*/true);
+ auto fn = GenerateSparseAdaGrad<int64_t>(block_size, /*rowwise=*/true);

double t = measureWithWarmup(
[&]() {
4 changes: 2 additions & 2 deletions bench/SparseAdagradBenchmark.cc
@@ -80,7 +80,7 @@ void run_benchmark(

double t = 0.0;
if (isIndex64b) {
- auto fn_indices_64 = GenerateSparseAdaGradNew<int64_t>(block_size);
+ auto fn_indices_64 = GenerateSparseAdaGrad<int64_t>(block_size);

t = measureWithWarmup(
[&]() {
@@ -112,7 +112,7 @@
lr);
}
} else {
- auto fn_indices_32 = GenerateSparseAdaGradNew<int32_t>(block_size);
+ auto fn_indices_32 = GenerateSparseAdaGrad<int32_t>(block_size);

t = measureWithWarmup(
[&]() {
12 changes: 2 additions & 10 deletions include/fbgemm/FbgemmEmbedding.h
@@ -159,15 +159,6 @@ template <typename IndexType>
class SparseAdaGradSignature {
public:
using Type = std::function<int(
- int num_rows, // number of rows reading
- std::uint64_t param_size, // total number of parameters
- float* w, // input/output parameters
- const float* g, // input gradients
- float* h, // input/output momentums
- const IndexType* indices, // indices of each row
- float epsilon,
- float lr)>;
- using NewType = std::function<int(
int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
@@ -177,6 +168,7 @@ class SparseAdaGradSignature {
float epsilon,
float lr,
float weight_decay)>;
+ using NewType = Type;
};

template <typename IndexType>
@@ -185,7 +177,7 @@ GenerateSparseAdaGrad(
int block_size, // number of parameters per row
bool rowwise = false,
int prefetch = 16,
- float weight_decay = 0.0f);
+ bool use_weight_decay = false);

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignature<IndexType>::NewType
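With the two signatures unified in the header above, callers can use GenerateSparseAdaGrad directly. A minimal end-to-end sketch against that header, assuming the fbgemm namespace, illustrative sizes and hyperparameters, and that the returned int counts successfully processed rows (as the tests below treat it):

#include "fbgemm/FbgemmEmbedding.h"
#include <cstdint>
#include <vector>

int main() {
  const int block_size = 4; // number of parameters per row (illustrative)
  const int table_rows = 10; // rows in the parameter table (illustrative)
  const int num_rows = 2; // rows updated in this step (illustrative)

  std::vector<float> w(table_rows * block_size, 1.0f); // input/output parameters
  std::vector<float> h(table_rows * block_size, 0.0f); // input/output momentums
  std::vector<float> g(num_rows * block_size, 0.1f); // input gradients
  std::vector<std::int64_t> indices = {3, 7}; // rows to update

  auto fn = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
      block_size, /*rowwise=*/false, /*prefetch=*/16, /*use_weight_decay=*/true);
  int rows_processed = fn(
      num_rows,
      w.size(), // total number of parameters
      w.data(),
      g.data(),
      h.data(),
      indices.data(),
      /*epsilon=*/1e-5f,
      /*lr=*/0.01f,
      /*weight_decay=*/1e-7f);
  return rows_processed == num_rows ? 0 : 1;
}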
100 changes: 12 additions & 88 deletions src/SparseAdagrad.cc
@@ -814,7 +814,7 @@ typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
int block_size,
bool rowwise,
int prefetch,
- float weight_decay) {
+ bool use_weight_decay) {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
@@ -828,7 +828,8 @@ typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
- float lr) {
+ float lr,
+ float weight_decay) {
return SparseAdaGradBlockSize1_(
num_rows,
param_size,
@@ -847,15 +848,16 @@ typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
[(VLEN - (block_size % VLEN)) % VLEN];
const auto original_func = kernel_generator.getOrCreate(
- block_size, prefetch, rowwise, weight_decay != 0.0f);
+ block_size, prefetch, rowwise, use_weight_decay);
return [=](int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
- float lr) {
+ float lr,
+ float weight_decay) {
return original_func(
num_rows, // number of rows reading
param_size, // total number of parameters
@@ -879,7 +881,8 @@ typename SparseAdaGradSignature<IndexType>::Type GenerateSparseAdaGrad(
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
- float lr) {
+ float lr,
+ float weight_decay) {
return sparse_adagrad_ref(
num_rows, // number of rows reading
block_size, // number of parameters per rows
@@ -900,102 +903,23 @@ GenerateSparseAdaGrad<std::int64_t>(
int block_size, // number of parameters per rows
bool rowwise,
int prefetch,
- float weight_decay);
+ bool use_weight_decay);

template FBGEMM_API typename SparseAdaGradSignature<std::int32_t>::Type
GenerateSparseAdaGrad<std::int32_t>(
int block_size, // number of parameters per rows
bool rowwise,
int prefetch,
- float weight_decay);
+ bool use_weight_decay);

template <typename IndexType>
typename SparseAdaGradSignature<IndexType>::NewType GenerateSparseAdaGradNew(
int block_size,
bool rowwise,
int prefetch,
bool use_weight_decay) {
- if (!cpuinfo_initialize()) {
- throw std::runtime_error("Failed to initialize cpuinfo!");
- }
-
- if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
- if (block_size == 1) {
- return [=](int num_rows, // number of rows reading
- std::uint64_t param_size, // total number of parameters
- float* w, // input/output parameters
- const float* g, // input gradients
- float* h, // input/output momentums
- const IndexType* indices, // indices of each row
- float epsilon,
- float lr,
- float weight_decay) {
- return SparseAdaGradBlockSize1_(
- num_rows,
- param_size,
- w,
- g,
- h,
- indices,
- epsilon,
- lr,
- rowwise,
- weight_decay);
- };
- }
- static GenSparseAdagrad<IndexType, inst_set_t::avx2> kernel_generator;
- constexpr int VLEN = simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS;
- const int* mask_avx2 = &internal::avx2_ps_or_epi32_combined_mask
- [(VLEN - (block_size % VLEN)) % VLEN];
- const auto original_func = kernel_generator.getOrCreate(
- block_size, prefetch, rowwise, use_weight_decay);
- return [=](int num_rows, // number of rows reading
- std::uint64_t param_size, // total number of parameters
- float* w, // input/output parameters
- const float* g, // input gradients
- float* h, // input/output momentums
- const IndexType* indices, // indices of each row
- float epsilon,
- float lr,
- float weight_decay) {
- return original_func(
- num_rows, // number of rows reading
- param_size, // total number of parameters
- w, // input/output parameters
- g, // input gradients
- h, // input/output momentums
- indices, // indices of each row
- epsilon,
- lr,
- mask_avx2,
- weight_decay);
- };
- } else {
- #ifdef VLOG
- VLOG(0) << "AVX2 or AVX512 not found, taking the slow path";
- #endif
- return [=](int num_rows, // number of rows reading
- std::uint64_t param_size, // total number of parameters
- float* w, // input/output parameters
- const float* g, // input gradients
- float* h, // input/output momentums
- const IndexType* indices, // indices of each row
- float epsilon,
- float lr,
- float weight_decay) {
- return sparse_adagrad_ref(
- num_rows, // number of rows reading
- block_size, // number of parameters per rows
- param_size, // total number of parameters
- w, // input/output parameters
- g, // input gradients
- h, // input/output momentums
- indices,
- epsilon,
- lr,
- weight_decay);
- };
- }
+ return GenerateSparseAdaGrad<IndexType>(
+ block_size, rowwise, prefetch, use_weight_decay);
}

template FBGEMM_API typename SparseAdaGradSignature<std::int64_t>::NewType
8 changes: 4 additions & 4 deletions test/SparseAdagradTest.cc
@@ -123,7 +123,7 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) {
lr,
weight_decay);

- auto fn_fbgemm = GenerateSparseAdaGradNew<std::int64_t>(
+ auto fn_fbgemm = GenerateSparseAdaGrad<std::int64_t>(
block_size, false, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
@@ -149,7 +149,7 @@ TEST_P(SparseAdagradTest, basicTest_two_stages) {
lr,
weight_decay);

- auto fn_fbgemm = GenerateSparseAdaGradNew<std::int32_t>(
+ auto fn_fbgemm = GenerateSparseAdaGrad<std::int32_t>(
block_size, false, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
@@ -239,7 +239,7 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
lr,
weight_decay);

- auto fn_fbgemm = GenerateSparseAdaGradNew<std::int64_t>(
+ auto fn_fbgemm = GenerateSparseAdaGrad<std::int64_t>(
block_size, true, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
@@ -265,7 +265,7 @@ TEST_P(SparseAdagradTest, rowwiseTest_two_stages) {
lr,
weight_decay);

- auto fn_fbgemm = GenerateSparseAdaGradNew<std::int32_t>(
+ auto fn_fbgemm = GenerateSparseAdaGrad<std::int32_t>(
block_size, true, prefetch, use_weight_decay);

ret_fbgemm = fn_fbgemm(
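The tests above compare the generated kernel against sparse_adagrad_ref in two stages. A condensed sketch of that pattern, not the verbatim test code: it assumes w, g, h, indices, and the hyperparameters (epsilon, lr, weight_decay, num_rows, block_size, prefetch, use_weight_decay) are already set up as in the snippet after FbgemmEmbedding.h above, plus gtest macros; the comparison tolerance is illustrative:

// Run the scalar reference and the JIT'd kernel on separate copies
// of the same data, then compare results element-wise.
std::vector<float> w_ref = w, w_jit = w; // parameter copies
std::vector<float> h_ref = h, h_jit = h; // momentum copies

int ret_ref = fbgemm::sparse_adagrad_ref(
    num_rows, block_size, w_ref.size(),
    w_ref.data(), g.data(), h_ref.data(), indices.data(),
    epsilon, lr, weight_decay);

auto fn_fbgemm = fbgemm::GenerateSparseAdaGrad<std::int64_t>(
    block_size, /*rowwise=*/false, prefetch, use_weight_decay);
int ret_fbgemm = fn_fbgemm(
    num_rows, w_jit.size(),
    w_jit.data(), g.data(), h_jit.data(), indices.data(),
    epsilon, lr, weight_decay);

EXPECT_EQ(ret_ref, ret_fbgemm); // both report the same number of rows processed
for (size_t i = 0; i < w_ref.size(); ++i) {
  EXPECT_NEAR(w_ref[i], w_jit[i], 1e-5f) << "mismatch at parameter " << i;
}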
