Skip to content

Commit

Permalink
remove temp interface for sparse adagrad (pytorch#515)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#515

As a final step, we remove the temporary duplicated interface.

Reviewed By: dskhudia

Differential Revision: D26232204

fbshipit-source-id: bf967ab7cf5387f7a9597cbd8bdb7d44cd70cb62
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Feb 13, 2021
1 parent 8db2382 commit 7f3baec
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 56 deletions.
32 changes: 0 additions & 32 deletions include/fbgemm/FbgemmEmbedding.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,30 +175,6 @@ class SparseAdaGradSignature {
std::int64_t counter_halflife)>; // frequency adjust happens only after
};

/**
 * Function type of a JIT-generated sparse AdaGrad kernel.
 *
 * NOTE: temporary duplicate of SparseAdaGradSignature (same parameter
 * list); kept only while callers migrate, then removed.
 *
 * @return The number of rows processed. If smaller than num_rows, an error
 * must have happened at the last row processed.
 */
template <typename IndexType>
class SparseAdaGradSignatureNew {
 public:
  using Type = std::function<int(
      int num_rows, // number of rows reading
      std::uint64_t param_size, // total number of parameters
      float* w, // input/output parameters
      const float* g, // input gradients
      float* h, // input/output momentums
      const IndexType* indices, // indices of each row
      float epsilon,
      float lr,
      float weight_decay,
      const double* counter, // used for weight_decay adjusted for frequency
                             // nullptr when frequency adjustment is not used.
                             // ignored when the kernel is generated with
                             // use_weight_decay = false.
      std::int64_t counter_halflife)>; // frequency adjust happens only after
};

template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignature<IndexType>::Type
GenerateSparseAdaGrad(
Expand All @@ -207,14 +183,6 @@ GenerateSparseAdaGrad(
int prefetch = 16,
bool use_weight_decay = false);

// Temporary duplicate of GenerateSparseAdaGrad; its definition simply
// forwards to GenerateSparseAdaGrad. Kept only while callers migrate.
template <typename IndexType>
FBGEMM_API typename SparseAdaGradSignatureNew<IndexType>::Type
GenerateSparseAdaGradNew(
    int block_size, // number of parameters per row
    bool rowwise = false,
    int prefetch = 16,
    bool use_weight_decay = false);

// RowWiseSparseAdaGrad fused with SLS gradient
// Weights can be either float or float16
template <
Expand Down
24 changes: 0 additions & 24 deletions src/SparseAdagrad.cc
Original file line number Diff line number Diff line change
Expand Up @@ -979,28 +979,4 @@ GenerateSparseAdaGrad<std::int32_t>(
int prefetch,
bool use_weight_decay);

/// Temporary compatibility shim around GenerateSparseAdaGrad: forwards all
/// arguments unchanged and hands back the generated kernel. Exists only so
/// callers can migrate off the "New" name.
template <typename IndexType>
typename SparseAdaGradSignatureNew<IndexType>::Type GenerateSparseAdaGradNew(
    int block_size,
    bool rowwise,
    int prefetch,
    bool use_weight_decay) {
  // Delegate to the canonical generator; both signature classes declare the
  // same std::function parameter list, so the result converts directly.
  auto kernel = GenerateSparseAdaGrad<IndexType>(
      block_size, rowwise, prefetch, use_weight_decay);
  return kernel;
}

// Explicit instantiation for 64-bit row indices.
template FBGEMM_API typename SparseAdaGradSignatureNew<std::int64_t>::Type
GenerateSparseAdaGradNew<std::int64_t>(
    int block_size, // number of parameters per row
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

// Explicit instantiation for 32-bit row indices.
template FBGEMM_API typename SparseAdaGradSignatureNew<std::int32_t>::Type
GenerateSparseAdaGradNew<std::int32_t>(
    int block_size, // number of parameters per row
    bool rowwise,
    int prefetch,
    bool use_weight_decay);

} // namespace fbgemm

0 comments on commit 7f3baec

Please sign in to comment.