forked from BVLC/caffe
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request BVLC#3116 from ronghanghu/solver-refactor
Solver Refactor: Separate files and Change Solver's Type to String
- Loading branch information
Showing
29 changed files
with
1,463 additions
and
1,047 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,148 @@ | ||
#ifndef CAFFE_SGD_SOLVERS_HPP_ | ||
#define CAFFE_SGD_SOLVERS_HPP_ | ||
|
||
#include <string> | ||
#include <vector> | ||
|
||
#include "caffe/solver.hpp" | ||
|
||
namespace caffe { | ||
|
||
/** | ||
* @brief Optimizes the parameters of a Net using | ||
* stochastic gradient descent (SGD) with momentum. | ||
*/ | ||
template <typename Dtype> | ||
class SGDSolver : public Solver<Dtype> { | ||
public: | ||
explicit SGDSolver(const SolverParameter& param) | ||
: Solver<Dtype>(param) { PreSolve(); } | ||
explicit SGDSolver(const string& param_file) | ||
: Solver<Dtype>(param_file) { PreSolve(); } | ||
virtual inline const char* type() const { return "SGD"; } | ||
|
||
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; } | ||
|
||
protected: | ||
void PreSolve(); | ||
Dtype GetLearningRate(); | ||
virtual void ApplyUpdate(); | ||
virtual void Normalize(int param_id); | ||
virtual void Regularize(int param_id); | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
virtual void ClipGradients(); | ||
virtual void SnapshotSolverState(const string& model_filename); | ||
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); | ||
virtual void SnapshotSolverStateToHDF5(const string& model_filename); | ||
virtual void RestoreSolverStateFromHDF5(const string& state_file); | ||
virtual void RestoreSolverStateFromBinaryProto(const string& state_file); | ||
// history maintains the historical momentum data. | ||
// update maintains update related data and is not needed in snapshots. | ||
// temp maintains other information that might be needed in computation | ||
// of gradients/updates and is not needed in snapshots | ||
vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_; | ||
|
||
DISABLE_COPY_AND_ASSIGN(SGDSolver); | ||
}; | ||
|
||
template <typename Dtype> | ||
class NesterovSolver : public SGDSolver<Dtype> { | ||
public: | ||
explicit NesterovSolver(const SolverParameter& param) | ||
: SGDSolver<Dtype>(param) {} | ||
explicit NesterovSolver(const string& param_file) | ||
: SGDSolver<Dtype>(param_file) {} | ||
virtual inline const char* type() const { return "Nesterov"; } | ||
|
||
protected: | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
|
||
DISABLE_COPY_AND_ASSIGN(NesterovSolver); | ||
}; | ||
|
||
template <typename Dtype> | ||
class AdaGradSolver : public SGDSolver<Dtype> { | ||
public: | ||
explicit AdaGradSolver(const SolverParameter& param) | ||
: SGDSolver<Dtype>(param) { constructor_sanity_check(); } | ||
explicit AdaGradSolver(const string& param_file) | ||
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } | ||
virtual inline const char* type() const { return "AdaGrad"; } | ||
|
||
protected: | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
void constructor_sanity_check() { | ||
CHECK_EQ(0, this->param_.momentum()) | ||
<< "Momentum cannot be used with AdaGrad."; | ||
} | ||
|
||
DISABLE_COPY_AND_ASSIGN(AdaGradSolver); | ||
}; | ||
|
||
|
||
template <typename Dtype> | ||
class RMSPropSolver : public SGDSolver<Dtype> { | ||
public: | ||
explicit RMSPropSolver(const SolverParameter& param) | ||
: SGDSolver<Dtype>(param) { constructor_sanity_check(); } | ||
explicit RMSPropSolver(const string& param_file) | ||
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); } | ||
virtual inline const char* type() const { return "RMSProp"; } | ||
|
||
protected: | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
void constructor_sanity_check() { | ||
CHECK_EQ(0, this->param_.momentum()) | ||
<< "Momentum cannot be used with RMSProp."; | ||
CHECK_GE(this->param_.rms_decay(), 0) | ||
<< "rms_decay should lie between 0 and 1."; | ||
CHECK_LT(this->param_.rms_decay(), 1) | ||
<< "rms_decay should lie between 0 and 1."; | ||
} | ||
|
||
DISABLE_COPY_AND_ASSIGN(RMSPropSolver); | ||
}; | ||
|
||
template <typename Dtype> | ||
class AdaDeltaSolver : public SGDSolver<Dtype> { | ||
public: | ||
explicit AdaDeltaSolver(const SolverParameter& param) | ||
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); } | ||
explicit AdaDeltaSolver(const string& param_file) | ||
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); } | ||
virtual inline const char* type() const { return "AdaDelta"; } | ||
|
||
protected: | ||
void AdaDeltaPreSolve(); | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
|
||
DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); | ||
}; | ||
|
||
/** | ||
* @brief AdamSolver, an algorithm for first-order gradient-based optimization | ||
* of stochastic objective functions, based on adaptive estimates of | ||
* lower-order moments. Described in [1]. | ||
* | ||
* [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization." | ||
* arXiv preprint arXiv:1412.6980v8 (2014). | ||
*/ | ||
template <typename Dtype> | ||
class AdamSolver : public SGDSolver<Dtype> { | ||
public: | ||
explicit AdamSolver(const SolverParameter& param) | ||
: SGDSolver<Dtype>(param) { AdamPreSolve();} | ||
explicit AdamSolver(const string& param_file) | ||
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); } | ||
virtual inline const char* type() const { return "Adam"; } | ||
|
||
protected: | ||
void AdamPreSolve(); | ||
virtual void ComputeUpdateValue(int param_id, Dtype rate); | ||
|
||
DISABLE_COPY_AND_ASSIGN(AdamSolver); | ||
}; | ||
|
||
} // namespace caffe | ||
|
||
#endif // CAFFE_SGD_SOLVERS_HPP_ |
Oops, something went wrong.