
Commit

Adaptive Weight decay based on Sparse feature ID frequency. (pytorch#508)

Summary:
Pull Request resolved: pytorch#508

Add an option to adjust weight decay based on the update frequency of each embedding
Add a new interface to maintain backward compatibility; the migration will require multiple steps of changes between fbgemm and Caffe2.
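
The core of the change, as implemented in the reference kernels below, is a per-row scaling factor on weight decay. The following standalone sketch (not part of the diff) mirrors that logic using the counter / counter_halflife names from the new signatures:

#include <cstdint>

// Sketch only: effective weight decay for the row addressed by idx, mirroring
// the reference implementations in this commit. counter holds per-row update
// counts; a null counter or a non-positive count leaves the decay unscaled.
inline float adjusted_weight_decay(
    float weight_decay,
    const double* counter,
    std::int64_t counter_halflife,
    std::uint64_t idx) {
  float freq =
      (counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;
  return weight_decay * freq;
}

// The per-element update then folds the scaled decay into the gradient:
//   float gj = std::fma(weight_decay * freq, w_[j], g_[j]);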

Reviewed By: dskhudia

Differential Revision: D26210974

fbshipit-source-id: 6936e2f8afaf94fd1273d99b02e63b5e3e574c4d
jspark1105 authored and facebook-github-bot committed Feb 9, 2021
1 parent a8bee90 commit 092497f
Showing 5 changed files with 365 additions and 67 deletions.
32 changes: 32 additions & 0 deletions include/fbgemm/FbgemmEmbedding.h
@@ -170,6 +170,30 @@ class SparseAdaGradSignature {
float weight_decay)>;
};

/**
* @return The number of rows processed. If smaller than num_rows, an error
* must have happened at the last row processed.
*/
template <typename IndexType>
class SparseAdaGradSignatureNew {
public:
using Type = std::function<int(
int num_rows, // number of rows reading
std::uint64_t param_size, // total number of parameters
float* w, // input/output parameters
const float* g, // input gradients
float* h, // input/output momentums
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay,
const double* counter, // used for weight_decay adjusted for frequency
// nullptr when frequency adjustment is not used.
// ignored when the kernel is generated with
// use_weight_decay = false.
std::int64_t counter_halflife)>; // frequency adjust happens only after this number of iterations
};

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
GenerateSparseAdaGrad(
@@ -178,6 +202,14 @@ GenerateSparseAdaGrad(
int prefetch = 16,
bool use_weight_decay = false);

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignatureNew<IndexType>::Type
GenerateSparseAdaGradNew(
int block_size, // number of parameters per row
bool rowwise = false,
int prefetch = 16,
bool use_weight_decay = false);

// RowWiseSparseAdaGrad fused with SLS gradient
// Weights can be either float or float16
template <
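
Not part of the diff: a minimal sketch of how a caller might drive the new interface, based only on the signatures above. Array sizes, hyperparameters, and the helper name are illustrative assumptions.

#include <cstdint>
#include <vector>

#include "fbgemm/FbgemmEmbedding.h"

void run_generated_kernel_example() {
  constexpr int block_size = 64;  // parameters per row

  // Generate a kernel with frequency-adjusted weight decay enabled.
  auto kernel = fbgemm::GenerateSparseAdaGradNew<std::int64_t>(
      block_size, /*rowwise=*/false, /*prefetch=*/16, /*use_weight_decay=*/true);

  constexpr int num_rows = 2;  // rows touched in this step
  constexpr std::uint64_t param_size = 1000 * block_size;  // 1000-row table
  std::vector<float> w(param_size, 0.1f);                  // parameters
  std::vector<float> h(param_size, 0.0f);                  // momentums
  std::vector<float> g(num_rows * block_size, 0.01f);      // gradients
  std::vector<std::int64_t> indices = {12, 345};
  std::vector<double> counter(1000, 0.0);  // per-row update counts

  int rows = kernel(
      num_rows, param_size, w.data(), g.data(), h.data(), indices.data(),
      /*epsilon=*/1e-5f, /*lr=*/0.05f, /*weight_decay=*/1e-4f,
      counter.data(), /*counter_halflife=*/1000000);
  (void)rows;  // a value smaller than num_rows signals an out-of-range index
}

The generated function follows SparseAdaGradSignatureNew<IndexType>::Type, so counter and counter_halflife are passed on every call; passing nullptr for counter disables the frequency adjustment.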
36 changes: 27 additions & 9 deletions src/RefImplementations.cc
@@ -1472,7 +1472,9 @@ int sparse_adagrad_ref(
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay) {
float weight_decay,
const double* counter,
const int64_t counter_halflife) {
for (auto i = 0; i < num_rows; ++i) {
uint64_t idx = indices[i];
auto offsetI = i * block_size;
@@ -1482,6 +1484,9 @@
return i;
}

float freq =
(counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;

const float* g_;
const float* h_;
const float* w_;
@@ -1495,7 +1500,7 @@
nw_ = w + offsetIdx;

for (auto j = 0; j < block_size; ++j) {
float gj = std::fma(weight_decay, w_[j], g_[j]);
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
float hj = h_[j] + gj * gj;
nh_[j] = hj;
nw_[j] = w_[j] + lr * gj / (std::sqrt(hj) + epsilon);
@@ -1515,7 +1520,9 @@ int rowwise_sparse_adagrad_ref(
const IndexType* indices, // indices of each row
float epsilon,
float lr,
float weight_decay) {
float weight_decay,
const double* counter,
const int64_t counter_halflife) {
for (auto i = 0; i < num_rows; ++i) {
uint64_t idx = indices[i];
auto offsetI = i * block_size;
@@ -1525,6 +1532,9 @@
return i;
}

float freq =
(counter && counter[idx] > 0) ? counter_halflife / counter[idx] : 1.0;

const float* g_;
float* h_;
float* w_;
@@ -1546,7 +1556,7 @@ int rowwise_sparse_adagrad_ref(
constexpr int VLEN = 8;
array<float, VLEN> partial_sum = {0.0f};
for (auto j = 0; j < block_size; ++j) {
float gj = std::fma(weight_decay, w_[j], g_[j]);
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
partial_sum[j % VLEN] += gj * gj;
}
final_sum = ((partial_sum[0] + partial_sum[1]) +
@@ -1557,7 +1567,7 @@
float float_step = lr / (std::sqrt(hi) + epsilon);

for (auto j = 0; j < block_size; ++j) {
float gj = std::fma(weight_decay, w_[j], g_[j]);
float gj = std::fma(weight_decay * freq, w_[j], g_[j]);
w_[j] += gj * float_step;
}
}
@@ -1813,7 +1823,9 @@ template FBGEMM_API int sparse_adagrad_ref(
const std::int64_t* indices, // indices of each row
float epsilon,
float lr,
float weight_decay);
float weight_decay,
const double* counter,
const int64_t counter_halflife);

template FBGEMM_API int sparse_adagrad_ref(
int num_rows, // number of rows reading
@@ -1825,7 +1837,9 @@ template FBGEMM_API int sparse_adagrad_ref(
const std::int32_t* indices, // indices of each row
float epsilon,
float lr,
float weight_decay);
float weight_decay,
const double* counter,
const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
int num_rows, // number of rows reading
@@ -1837,7 +1851,9 @@ template FBGEMM_API int rowwise_sparse_adagrad_ref(
const std::int64_t* indices, // indices of each row
float epsilon,
float lr,
float weight_decay);
float weight_decay,
const double* counter,
const int64_t counter_halflife);

template FBGEMM_API int rowwise_sparse_adagrad_ref(
int num_rows, // number of rows reading
@@ -1849,7 +1865,9 @@ template FBGEMM_API int rowwise_sparse_adagrad_ref(
const std::int32_t* indices, // indices of each row
float epsilon,
float lr,
float weight_decay);
float weight_decay,
const double* counter,
const int64_t counter_halflife);

#define INSTANTIATE_SPMDM_BASE(DATA_TYPE, INDEX_TYPE, OFFSET_TYPE) \
template FBGEMM_API int rowwise_sparse_adagrad_fused_ref( \
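
Not part of the diff: a small numeric check of the freq factor used above, with hypothetical values (counter_halflife = 1e6, weight_decay = 1e-4):

#include <cstdio>

int main() {
  const double counter_halflife = 1e6;
  const float weight_decay = 1e-4f;
  const double counts[] = {4e6, 5e5, 0.0};  // hypothetical per-row update counts
  for (double c : counts) {
    float freq = (c > 0) ? counter_halflife / c : 1.0;  // same rule as the ref code
    std::printf("count=%.0f freq=%.2f effective_decay=%g\n",
                c, freq, weight_decay * freq);
  }
  return 0;
}

A row with 4e6 recorded updates gets freq = 0.25 (effective decay 2.5e-5), a row with 5e5 updates gets freq = 2.0 (2e-4), and a zero count or a null counter falls back to freq = 1.0, i.e. the unadjusted weight_decay.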
66 changes: 49 additions & 17 deletions src/RefImplementations.h
@@ -9,9 +9,9 @@
#include <algorithm>
#include <cstdint>

#include "fbgemm/Types.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/FbgemmI8Spmdm.h"
#include "fbgemm/Types.h"

namespace fbgemm {

@@ -286,31 +286,63 @@ FBGEMM_API bool EmbeddingSpMDMNBitRowWiseSparse_ref(
bool is_weight_positional = false,
bool use_offsets = true);

/**
* @param num_rows number of rows reading
* @param block_size number of parameters per row
* @param param_size total number of parameters
* @param w input parameters
* @param g input gradients
* @param h input momentum
* @param indices indices of each row
* @param counter used for weight_decay adjusted for frequency. nullptr when
* frequency adjustment is not used. Ignored when weight_decay
* == 0
* @param counter_halflife weight_decay is adjusted only after this number of
* iterations
*/
template <typename IndexType>
FBGEMM_API int sparse_adagrad_ref(
int num_rows, // number of rows reading
int block_size, // number of parameters per rows
std::uint64_t param_size, // total number of parameters
float* w, // input parameters
const float* g, // input gradients
float* h, // input momentums
const IndexType* indices, // indices of each row
int num_rows,
int block_size,
std::uint64_t param_size,
float* w,
const float* g,
float* h,
const IndexType* indices,
float epsilon,
float lr,
float weight_decay = 0.f);
float weight_decay = 0.f,
const double* counter = nullptr,
const int64_t counter_halflife = 0);

/**
* @param num_rows number of rows reading
* @param block_size number of parameters per row
* @param param_size total number of parameters
* @param w input parameters
* @param g input gradients
* @param h input momentum
* @param indices indices of each row
* @param counter used for weight_decay adjusted for frequency. nullptr when
* frequency adjustment is not used. Ignored when weight_decay
* == 0
* @param counter_halflife weight_decay is adjusted only after this number of
* iterations
*/
template <typename IndexType>
FBGEMM_API int rowwise_sparse_adagrad_ref(
int num_rows, // number of rows reading
int block_size, // number of parameters per rows
std::uint64_t param_size, // total number of parameters
float* w, // input parameters
const float* g, // input gradients
float* h, // input momentums
const IndexType* indices, // indices of each row
int num_rows,
int block_size,
std::uint64_t param_size,
float* w,
const float* g,
float* h,
const IndexType* indices,
float epsilon,
float lr,
float weight_decay = 0.f);
float weight_decay = 0.f,
const double* counter = nullptr,
const int64_t counter_halflife = 0);

template <typename DataType, typename IndexType, typename OffsetType>
FBGEMM_API int rowwise_sparse_adagrad_fused_ref(
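
Not part of the diff: a sketch of exercising the extended reference signature directly. The include path, helper name, sizes, and values are assumptions for illustration only.

#include <cstdint>
#include <vector>

#include "src/RefImplementations.h"  // assumed path within the fbgemm tree

void run_sparse_adagrad_ref_example() {
  constexpr int num_rows = 2;    // rows touched in this step
  constexpr int block_size = 4;  // embedding dimension
  constexpr std::uint64_t param_size = 8 * block_size;  // 8-row table

  std::vector<float> w(param_size, 0.1f);              // parameters
  std::vector<float> h(param_size, 0.0f);              // momentums
  std::vector<float> g(num_rows * block_size, 0.01f);  // gradients
  std::vector<std::int64_t> indices = {3, 5};

  // Per-row update counts; the reference kernel scales weight_decay by
  // counter_halflife / counter[idx] when the count is positive.
  std::vector<double> counter(8, 0.0);
  counter[3] = 4e6;
  counter[5] = 5e5;

  int rows_processed = fbgemm::sparse_adagrad_ref(
      num_rows, block_size, param_size,
      w.data(), g.data(), h.data(), indices.data(),
      /*epsilon=*/1e-5f, /*lr=*/0.05f, /*weight_decay=*/1e-4f,
      counter.data(), /*counter_halflife=*/1000000);
  (void)rows_processed;  // smaller than num_rows only if a row index was out of range
}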
(Diffs for the remaining changed files are not shown here.)
