Skip to content

Commit

Permalink
Merge pull request BVLC#3116 from ronghanghu/solver-refactor
Browse files Browse the repository at this point in the history
Solver Refactor: Separate files and Change Solver's Type to String
  • Loading branch information
shelhamer committed Oct 17, 2015
2 parents 46dac40 + 9563537 commit 16de340
Show file tree
Hide file tree
Showing 29 changed files with 1,463 additions and 1,047 deletions.
28 changes: 14 additions & 14 deletions docs/tutorial/solver.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@ The responsibilities of learning are divided between the Solver for overseeing t

The Caffe solvers are:

- Stochastic Gradient Descent (`SGD`),
- AdaDelta (`ADADELTA`),
- Adaptive Gradient (`ADAGRAD`),
- Adam (`ADAM`),
- Nesterov's Accelerated Gradient (`NESTEROV`) and
- RMSprop (`RMSPROP`)
- Stochastic Gradient Descent (`type: "SGD"`),
- AdaDelta (`type: "AdaDelta"`),
- Adaptive Gradient (`type: "AdaGrad"`),
- Adam (`type: "Adam"`),
- Nesterov's Accelerated Gradient (`type: "Nesterov"`) and
- RMSprop (`type: "RMSProp"`)

The solver

Expand Down Expand Up @@ -51,7 +51,7 @@ The parameter update $$\Delta W$$ is formed by the solver from the error gradien

### SGD

**Stochastic gradient descent** (`solver_type: SGD`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
**Stochastic gradient descent** (`type: "SGD"`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$.
The **learning rate** $$ \alpha $$ is the weight of the negative gradient.
The **momentum** $$ \mu $$ is the weight of the previous update.

Expand Down Expand Up @@ -113,7 +113,7 @@ If learning diverges (e.g., you start to see very large or `NaN` or `inf` loss v

### AdaDelta

The **AdaDelta** (`solver_type: ADADELTA`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are
The **AdaDelta** (`type: "AdaDelta"`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are

$$
\begin{align}
Expand All @@ -125,7 +125,7 @@ E[g^2]_t &= \delta{E[g^2]_{t-1} } + (1-\delta)g_{t}^2
\end{align}
$$

and
and

$$
(W_{t+1})_i =
Expand All @@ -139,7 +139,7 @@ $$

### AdaGrad

The **adaptive gradient** (`solver_type: ADAGRAD`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
The **adaptive gradient** (`type: "AdaGrad"`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words.
Given the update information from all previous iterations $$ \left( \nabla L(W) \right)_{t'} $$ for $$ t' \in \{1, 2, ..., t\} $$,
the update formulas proposed by [1] are as follows, specified for each component $$i$$ of the weights $$W$$:

Expand All @@ -159,7 +159,7 @@ Note that in practice, for weights $$ W \in \mathcal{R}^d $$, AdaGrad implementa

### Adam

The **Adam** (`solver_type: ADAM`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are
The **Adam** (`type: "Adam"`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are

$$
(m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i,\\
Expand All @@ -181,7 +181,7 @@ Kingma et al. [1] proposed to use $$\beta_1 = 0.9, \beta_2 = 0.999, \varepsilon

### NAG

**Nesterov's accelerated gradient** (`solver_type: NESTEROV`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
**Nesterov's accelerated gradient** (`type: "Nesterov"`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$.
Though the required assumptions to achieve the $$ \mathcal{O}(1/t^2) $$ convergence typically will not hold for deep networks trained with Caffe (e.g., due to non-smoothness and non-convexity), in practice NAG can be a very effective method for optimizing certain types of deep learning architectures, as demonstrated for deep MNIST autoencoders by Sutskever et al. [2].

The weight update formulas look very similar to the SGD updates given above:
Expand All @@ -206,10 +206,10 @@ What distinguishes the method from SGD is the weight setting $$ W $$ on which we

### RMSprop

The **RMSprop** (`solver_type: RMSPROP`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are
The **RMSprop** (`type: "RMSProp"`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are

$$
(v_t)_i =
(v_t)_i =
\begin{cases}
(v_{t-1})_i + \delta, &(\nabla L(W_t))_i(\nabla L(W_{t-1}))_i > 0\\
(v_{t-1})_i \cdot (1-\delta), & \text{else}
Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/lenet_adadelta_solver.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@ snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_adadelta"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: ADADELTA
type: "AdaDelta"
delta: 1e-6
2 changes: 1 addition & 1 deletion examples/mnist/lenet_solver_adam.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ max_iter: 10000
snapshot: 5000
snapshot_prefix: "examples/mnist/lenet"
# solver mode: CPU or GPU
solver_type: ADAM
type: "Adam"
solver_mode: GPU
2 changes: 1 addition & 1 deletion examples/mnist/lenet_solver_rmsprop.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,5 @@ snapshot: 5000
snapshot_prefix: "examples/mnist/lenet_rmsprop"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: RMSPROP
type: "RMSProp"
rms_decay: 0.98
2 changes: 1 addition & 1 deletion examples/mnist/mnist_autoencoder_solver_adadelta.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@ snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: ADADELTA
type: "AdaDelta"
2 changes: 1 addition & 1 deletion examples/mnist/mnist_autoencoder_solver_adagrad.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ snapshot: 10000
snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train"
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: ADAGRAD
type: "AdaGrad"
2 changes: 1 addition & 1 deletion examples/mnist/mnist_autoencoder_solver_nesterov.prototxt
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train"
momentum: 0.95
# solver mode: CPU or GPU
solver_mode: GPU
solver_type: NESTEROV
type: "Nesterov"
2 changes: 2 additions & 0 deletions include/caffe/caffe.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,10 @@
#include "caffe/parallel.hpp"
#include "caffe/proto/caffe.pb.h"
#include "caffe/solver.hpp"
#include "caffe/solver_factory.hpp"
#include "caffe/util/benchmark.hpp"
#include "caffe/util/io.hpp"
#include "caffe/util/upgrade_proto.hpp"
#include "caffe/vision_layers.hpp"

#endif // CAFFE_CAFFE_HPP_
148 changes: 148 additions & 0 deletions include/caffe/sgd_solvers.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
#ifndef CAFFE_SGD_SOLVERS_HPP_
#define CAFFE_SGD_SOLVERS_HPP_

#include <string>
#include <vector>

#include "caffe/solver.hpp"

namespace caffe {

/**
* @brief Optimizes the parameters of a Net using
* stochastic gradient descent (SGD) with momentum.
*/
template <typename Dtype>
class SGDSolver : public Solver<Dtype> {
public:
explicit SGDSolver(const SolverParameter& param)
: Solver<Dtype>(param) { PreSolve(); }
explicit SGDSolver(const string& param_file)
: Solver<Dtype>(param_file) { PreSolve(); }
virtual inline const char* type() const { return "SGD"; }

const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

protected:
void PreSolve();
Dtype GetLearningRate();
virtual void ApplyUpdate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
virtual void ClipGradients();
virtual void SnapshotSolverState(const string& model_filename);
virtual void SnapshotSolverStateToBinaryProto(const string& model_filename);
virtual void SnapshotSolverStateToHDF5(const string& model_filename);
virtual void RestoreSolverStateFromHDF5(const string& state_file);
virtual void RestoreSolverStateFromBinaryProto(const string& state_file);
// history maintains the historical momentum data.
// update maintains update related data and is not needed in snapshots.
// temp maintains other information that might be needed in computation
// of gradients/updates and is not needed in snapshots
vector<shared_ptr<Blob<Dtype> > > history_, update_, temp_;

DISABLE_COPY_AND_ASSIGN(SGDSolver);
};

template <typename Dtype>
class NesterovSolver : public SGDSolver<Dtype> {
public:
explicit NesterovSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) {}
explicit NesterovSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) {}
virtual inline const char* type() const { return "Nesterov"; }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);

DISABLE_COPY_AND_ASSIGN(NesterovSolver);
};

template <typename Dtype>
class AdaGradSolver : public SGDSolver<Dtype> {
public:
explicit AdaGradSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit AdaGradSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
virtual inline const char* type() const { return "AdaGrad"; }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with AdaGrad.";
}

DISABLE_COPY_AND_ASSIGN(AdaGradSolver);
};


template <typename Dtype>
class RMSPropSolver : public SGDSolver<Dtype> {
public:
explicit RMSPropSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { constructor_sanity_check(); }
explicit RMSPropSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { constructor_sanity_check(); }
virtual inline const char* type() const { return "RMSProp"; }

protected:
virtual void ComputeUpdateValue(int param_id, Dtype rate);
void constructor_sanity_check() {
CHECK_EQ(0, this->param_.momentum())
<< "Momentum cannot be used with RMSProp.";
CHECK_GE(this->param_.rms_decay(), 0)
<< "rms_decay should lie between 0 and 1.";
CHECK_LT(this->param_.rms_decay(), 1)
<< "rms_decay should lie between 0 and 1.";
}

DISABLE_COPY_AND_ASSIGN(RMSPropSolver);
};

template <typename Dtype>
class AdaDeltaSolver : public SGDSolver<Dtype> {
public:
explicit AdaDeltaSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { AdaDeltaPreSolve(); }
explicit AdaDeltaSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdaDeltaPreSolve(); }
virtual inline const char* type() const { return "AdaDelta"; }

protected:
void AdaDeltaPreSolve();
virtual void ComputeUpdateValue(int param_id, Dtype rate);

DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver);
};

/**
* @brief AdamSolver, an algorithm for first-order gradient-based optimization
* of stochastic objective functions, based on adaptive estimates of
* lower-order moments. Described in [1].
*
* [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization."
* arXiv preprint arXiv:1412.6980v8 (2014).
*/
template <typename Dtype>
class AdamSolver : public SGDSolver<Dtype> {
public:
explicit AdamSolver(const SolverParameter& param)
: SGDSolver<Dtype>(param) { AdamPreSolve();}
explicit AdamSolver(const string& param_file)
: SGDSolver<Dtype>(param_file) { AdamPreSolve(); }
virtual inline const char* type() const { return "Adam"; }

protected:
void AdamPreSolve();
virtual void ComputeUpdateValue(int param_id, Dtype rate);

DISABLE_COPY_AND_ASSIGN(AdamSolver);
};

} // namespace caffe

#endif // CAFFE_SGD_SOLVERS_HPP_
Loading

0 comments on commit 16de340

Please sign in to comment.