diff --git a/docs/tutorial/solver.md b/docs/tutorial/solver.md index b150f6487bc..b719f715a4b 100644 --- a/docs/tutorial/solver.md +++ b/docs/tutorial/solver.md @@ -8,12 +8,12 @@ The responsibilities of learning are divided between the Solver for overseeing t The Caffe solvers are: -- Stochastic Gradient Descent (`SGD`), -- AdaDelta (`ADADELTA`), -- Adaptive Gradient (`ADAGRAD`), -- Adam (`ADAM`), -- Nesterov's Accelerated Gradient (`NESTEROV`) and -- RMSprop (`RMSPROP`) +- Stochastic Gradient Descent (`type: "SGD"`), +- AdaDelta (`type: "AdaDelta"`), +- Adaptive Gradient (`type: "AdaGrad"`), +- Adam (`type: "Adam"`), +- Nesterov's Accelerated Gradient (`type: "Nesterov"`) and +- RMSprop (`type: "RMSProp"`) The solver @@ -51,7 +51,7 @@ The parameter update $$\Delta W$$ is formed by the solver from the error gradien ### SGD -**Stochastic gradient descent** (`solver_type: SGD`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$. +**Stochastic gradient descent** (`type: "SGD"`) updates the weights $$ W $$ by a linear combination of the negative gradient $$ \nabla L(W) $$ and the previous weight update $$ V_t $$. The **learning rate** $$ \alpha $$ is the weight of the negative gradient. The **momentum** $$ \mu $$ is the weight of the previous update. @@ -113,7 +113,7 @@ If learning diverges (e.g., you start to see very large or `NaN` or `inf` loss v ### AdaDelta -The **AdaDelta** (`solver_type: ADADELTA`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). The update formulas are +The **AdaDelta** (`type: "AdaDelta"`) method (M. Zeiler [1]) is a "robust learning rate method". It is a gradient-based optimization method (like SGD). 
The update formulas are $$ \begin{align} @@ -125,7 +125,7 @@ E[g^2]_t &= \delta{E[g^2]_{t-1} } + (1-\delta)g_{t}^2 \end{align} $$ -and +and $$ (W_{t+1})_i = @@ -139,7 +139,7 @@ $$ ### AdaGrad -The **adaptive gradient** (`solver_type: ADAGRAD`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words. +The **adaptive gradient** (`type: "AdaGrad"`) method (Duchi et al. [1]) is a gradient-based optimization method (like SGD) that attempts to "find needles in haystacks in the form of very predictive but rarely seen features," in Duchi et al.'s words. Given the update information from all previous iterations $$ \left( \nabla L(W) \right)_{t'} $$ for $$ t' \in \{1, 2, ..., t\} $$, the update formulas proposed by [1] are as follows, specified for each component $$i$$ of the weights $$W$$: @@ -159,7 +159,7 @@ Note that in practice, for weights $$ W \in \mathcal{R}^d $$, AdaGrad implementa ### Adam -The **Adam** (`solver_type: ADAM`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are +The **Adam** (`type: "Adam"`), proposed in Kingma et al. [1], is a gradient-based optimization method (like SGD). This includes an "adaptive moment estimation" ($$m_t, v_t$$) and can be regarded as a generalization of AdaGrad. The update formulas are $$ (m_t)_i = \beta_1 (m_{t-1})_i + (1-\beta_1)(\nabla L(W_t))_i,\\ @@ -181,7 +181,7 @@ Kingma et al. [1] proposed to use $$\beta_1 = 0.9, \beta_2 = 0.999, \varepsilon ### NAG -**Nesterov's accelerated gradient** (`solver_type: NESTEROV`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$. 
+**Nesterov's accelerated gradient** (`type: "Nesterov"`) was proposed by Nesterov [1] as an "optimal" method of convex optimization, achieving a convergence rate of $$ \mathcal{O}(1/t^2) $$ rather than the $$ \mathcal{O}(1/t) $$. Though the required assumptions to achieve the $$ \mathcal{O}(1/t^2) $$ convergence typically will not hold for deep networks trained with Caffe (e.g., due to non-smoothness and non-convexity), in practice NAG can be a very effective method for optimizing certain types of deep learning architectures, as demonstrated for deep MNIST autoencoders by Sutskever et al. [2]. The weight update formulas look very similar to the SGD updates given above: @@ -206,10 +206,10 @@ What distinguishes the method from SGD is the weight setting $$ W $$ on which we ### RMSprop -The **RMSprop** (`solver_type: RMSPROP`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). The update formulas are +The **RMSprop** (`type: "RMSProp"`), suggested by Tieleman in a Coursera course lecture, is a gradient-based optimization method (like SGD). 
The update formulas are $$ -(v_t)_i = +(v_t)_i = \begin{cases} (v_{t-1})_i + \delta, &(\nabla L(W_t))_i(\nabla L(W_{t-1}))_i > 0\\ (v_{t-1})_i \cdot (1-\delta), & \text{else} diff --git a/examples/mnist/lenet_adadelta_solver.prototxt b/examples/mnist/lenet_adadelta_solver.prototxt index 776d1e06139..16176c0ffae 100644 --- a/examples/mnist/lenet_adadelta_solver.prototxt +++ b/examples/mnist/lenet_adadelta_solver.prototxt @@ -20,5 +20,5 @@ snapshot: 5000 snapshot_prefix: "examples/mnist/lenet_adadelta" # solver mode: CPU or GPU solver_mode: GPU -solver_type: ADADELTA +type: "AdaDelta" delta: 1e-6 diff --git a/examples/mnist/lenet_solver_adam.prototxt b/examples/mnist/lenet_solver_adam.prototxt index d22c5718f3f..4b5336b1a04 100644 --- a/examples/mnist/lenet_solver_adam.prototxt +++ b/examples/mnist/lenet_solver_adam.prototxt @@ -22,5 +22,5 @@ max_iter: 10000 snapshot: 5000 snapshot_prefix: "examples/mnist/lenet" # solver mode: CPU or GPU -solver_type: ADAM +type: "Adam" solver_mode: GPU diff --git a/examples/mnist/lenet_solver_rmsprop.prototxt b/examples/mnist/lenet_solver_rmsprop.prototxt index 74dadc51069..924b72d306e 100644 --- a/examples/mnist/lenet_solver_rmsprop.prototxt +++ b/examples/mnist/lenet_solver_rmsprop.prototxt @@ -23,5 +23,5 @@ snapshot: 5000 snapshot_prefix: "examples/mnist/lenet_rmsprop" # solver mode: CPU or GPU solver_mode: GPU -solver_type: RMSPROP +type: "RMSProp" rms_decay: 0.98 diff --git a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt index 065647df31b..26c4084a374 100644 --- a/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adadelta.prototxt @@ -16,4 +16,4 @@ snapshot: 10000 snapshot_prefix: "examples/mnist/mnist_autoencoder_adadelta_train" # solver mode: CPU or GPU solver_mode: GPU -solver_type: ADADELTA +type: "AdaDelta" diff --git a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt 
b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt index cc0ed9e310a..065cdb20ddc 100644 --- a/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_adagrad.prototxt @@ -14,4 +14,4 @@ snapshot: 10000 snapshot_prefix: "examples/mnist/mnist_autoencoder_adagrad_train" # solver mode: CPU or GPU solver_mode: GPU -solver_type: ADAGRAD +type: "AdaGrad" diff --git a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt index 2a59fd45c8d..c95e3fe7e49 100644 --- a/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt +++ b/examples/mnist/mnist_autoencoder_solver_nesterov.prototxt @@ -17,4 +17,4 @@ snapshot_prefix: "examples/mnist/mnist_autoencoder_nesterov_train" momentum: 0.95 # solver mode: CPU or GPU solver_mode: GPU -solver_type: NESTEROV +type: "Nesterov" diff --git a/include/caffe/caffe.hpp b/include/caffe/caffe.hpp index 68a5e1d1d1a..a339efba5c0 100644 --- a/include/caffe/caffe.hpp +++ b/include/caffe/caffe.hpp @@ -13,8 +13,10 @@ #include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" +#include "caffe/solver_factory.hpp" #include "caffe/util/benchmark.hpp" #include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" #include "caffe/vision_layers.hpp" #endif // CAFFE_CAFFE_HPP_ diff --git a/include/caffe/sgd_solvers.hpp b/include/caffe/sgd_solvers.hpp new file mode 100644 index 00000000000..1fc52d87137 --- /dev/null +++ b/include/caffe/sgd_solvers.hpp @@ -0,0 +1,148 @@ +#ifndef CAFFE_SGD_SOLVERS_HPP_ +#define CAFFE_SGD_SOLVERS_HPP_ + +#include +#include + +#include "caffe/solver.hpp" + +namespace caffe { + +/** + * @brief Optimizes the parameters of a Net using + * stochastic gradient descent (SGD) with momentum. 
+ */ +template +class SGDSolver : public Solver { + public: + explicit SGDSolver(const SolverParameter& param) + : Solver(param) { PreSolve(); } + explicit SGDSolver(const string& param_file) + : Solver(param_file) { PreSolve(); } + virtual inline const char* type() const { return "SGD"; } + + const vector > >& history() { return history_; } + + protected: + void PreSolve(); + Dtype GetLearningRate(); + virtual void ApplyUpdate(); + virtual void Normalize(int param_id); + virtual void Regularize(int param_id); + virtual void ComputeUpdateValue(int param_id, Dtype rate); + virtual void ClipGradients(); + virtual void SnapshotSolverState(const string& model_filename); + virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); + virtual void SnapshotSolverStateToHDF5(const string& model_filename); + virtual void RestoreSolverStateFromHDF5(const string& state_file); + virtual void RestoreSolverStateFromBinaryProto(const string& state_file); + // history maintains the historical momentum data. + // update maintains update related data and is not needed in snapshots. 
+ // temp maintains other information that might be needed in computation + // of gradients/updates and is not needed in snapshots + vector > > history_, update_, temp_; + + DISABLE_COPY_AND_ASSIGN(SGDSolver); +}; + +template +class NesterovSolver : public SGDSolver { + public: + explicit NesterovSolver(const SolverParameter& param) + : SGDSolver(param) {} + explicit NesterovSolver(const string& param_file) + : SGDSolver(param_file) {} + virtual inline const char* type() const { return "Nesterov"; } + + protected: + virtual void ComputeUpdateValue(int param_id, Dtype rate); + + DISABLE_COPY_AND_ASSIGN(NesterovSolver); +}; + +template +class AdaGradSolver : public SGDSolver { + public: + explicit AdaGradSolver(const SolverParameter& param) + : SGDSolver(param) { constructor_sanity_check(); } + explicit AdaGradSolver(const string& param_file) + : SGDSolver(param_file) { constructor_sanity_check(); } + virtual inline const char* type() const { return "AdaGrad"; } + + protected: + virtual void ComputeUpdateValue(int param_id, Dtype rate); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with AdaGrad."; + } + + DISABLE_COPY_AND_ASSIGN(AdaGradSolver); +}; + + +template +class RMSPropSolver : public SGDSolver { + public: + explicit RMSPropSolver(const SolverParameter& param) + : SGDSolver(param) { constructor_sanity_check(); } + explicit RMSPropSolver(const string& param_file) + : SGDSolver(param_file) { constructor_sanity_check(); } + virtual inline const char* type() const { return "RMSProp"; } + + protected: + virtual void ComputeUpdateValue(int param_id, Dtype rate); + void constructor_sanity_check() { + CHECK_EQ(0, this->param_.momentum()) + << "Momentum cannot be used with RMSProp."; + CHECK_GE(this->param_.rms_decay(), 0) + << "rms_decay should lie between 0 and 1."; + CHECK_LT(this->param_.rms_decay(), 1) + << "rms_decay should lie between 0 and 1."; + } + + DISABLE_COPY_AND_ASSIGN(RMSPropSolver); +}; + 
+template +class AdaDeltaSolver : public SGDSolver { + public: + explicit AdaDeltaSolver(const SolverParameter& param) + : SGDSolver(param) { AdaDeltaPreSolve(); } + explicit AdaDeltaSolver(const string& param_file) + : SGDSolver(param_file) { AdaDeltaPreSolve(); } + virtual inline const char* type() const { return "AdaDelta"; } + + protected: + void AdaDeltaPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); + + DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); +}; + +/** + * @brief AdamSolver, an algorithm for first-order gradient-based optimization + * of stochastic objective functions, based on adaptive estimates of + * lower-order moments. Described in [1]. + * + * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization." + * arXiv preprint arXiv:1412.6980v8 (2014). + */ +template +class AdamSolver : public SGDSolver { + public: + explicit AdamSolver(const SolverParameter& param) + : SGDSolver(param) { AdamPreSolve();} + explicit AdamSolver(const string& param_file) + : SGDSolver(param_file) { AdamPreSolve(); } + virtual inline const char* type() const { return "Adam"; } + + protected: + void AdamPreSolve(); + virtual void ComputeUpdateValue(int param_id, Dtype rate); + + DISABLE_COPY_AND_ASSIGN(AdamSolver); +}; + +} // namespace caffe + +#endif // CAFFE_SGD_SOLVERS_HPP_ diff --git a/include/caffe/solver.hpp b/include/caffe/solver.hpp index 2ecf539baef..298a68f37df 100644 --- a/include/caffe/solver.hpp +++ b/include/caffe/solver.hpp @@ -1,10 +1,11 @@ -#ifndef CAFFE_OPTIMIZATION_SOLVER_HPP_ -#define CAFFE_OPTIMIZATION_SOLVER_HPP_ +#ifndef CAFFE_SOLVER_HPP_ +#define CAFFE_SOLVER_HPP_ #include #include #include #include "caffe/net.hpp" +#include "caffe/solver_factory.hpp" namespace caffe { @@ -83,6 +84,10 @@ class Solver { } void CheckSnapshotWritePermissions(); + /** + * @brief Returns the solver type. 
+ */ + virtual inline const char* type() const { return ""; } protected: // Make and apply the update value for the current iteration. @@ -148,158 +153,6 @@ class WorkerSolver : public Solver { } }; -/** - * @brief Optimizes the parameters of a Net using - * stochastic gradient descent (SGD) with momentum. - */ -template -class SGDSolver : public Solver { - public: - explicit SGDSolver(const SolverParameter& param) - : Solver(param) { PreSolve(); } - explicit SGDSolver(const string& param_file) - : Solver(param_file) { PreSolve(); } - - const vector > >& history() { return history_; } - - protected: - void PreSolve(); - Dtype GetLearningRate(); - virtual void ApplyUpdate(); - virtual void Normalize(int param_id); - virtual void Regularize(int param_id); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - virtual void ClipGradients(); - virtual void SnapshotSolverState(const string& model_filename); - virtual void SnapshotSolverStateToBinaryProto(const string& model_filename); - virtual void SnapshotSolverStateToHDF5(const string& model_filename); - virtual void RestoreSolverStateFromHDF5(const string& state_file); - virtual void RestoreSolverStateFromBinaryProto(const string& state_file); - // history maintains the historical momentum data. - // update maintains update related data and is not needed in snapshots. 
- // temp maintains other information that might be needed in computation - // of gradients/updates and is not needed in snapshots - vector > > history_, update_, temp_; - - DISABLE_COPY_AND_ASSIGN(SGDSolver); -}; - -template -class NesterovSolver : public SGDSolver { - public: - explicit NesterovSolver(const SolverParameter& param) - : SGDSolver(param) {} - explicit NesterovSolver(const string& param_file) - : SGDSolver(param_file) {} - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(NesterovSolver); -}; - -template -class AdaGradSolver : public SGDSolver { - public: - explicit AdaGradSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit AdaGradSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.momentum()) - << "Momentum cannot be used with AdaGrad."; - } - - DISABLE_COPY_AND_ASSIGN(AdaGradSolver); -}; - - -template -class RMSPropSolver : public SGDSolver { - public: - explicit RMSPropSolver(const SolverParameter& param) - : SGDSolver(param) { constructor_sanity_check(); } - explicit RMSPropSolver(const string& param_file) - : SGDSolver(param_file) { constructor_sanity_check(); } - - protected: - virtual void ComputeUpdateValue(int param_id, Dtype rate); - void constructor_sanity_check() { - CHECK_EQ(0, this->param_.momentum()) - << "Momentum cannot be used with RMSProp."; - CHECK_GE(this->param_.rms_decay(), 0) - << "rms_decay should lie between 0 and 1."; - CHECK_LT(this->param_.rms_decay(), 1) - << "rms_decay should lie between 0 and 1."; - } - - DISABLE_COPY_AND_ASSIGN(RMSPropSolver); -}; - -template -class AdaDeltaSolver : public SGDSolver { - public: - explicit AdaDeltaSolver(const SolverParameter& param) - : SGDSolver(param) { AdaDeltaPreSolve(); } - explicit AdaDeltaSolver(const 
string& param_file) - : SGDSolver(param_file) { AdaDeltaPreSolve(); } - - protected: - void AdaDeltaPreSolve(); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(AdaDeltaSolver); -}; - -/** - * @brief AdamSolver, an algorithm for first-order gradient-based optimization - * of stochastic objective functions, based on adaptive estimates of - * lower-order moments. Described in [1]. - * - * [1] D. P. Kingma and J. L. Ba, "ADAM: A Method for Stochastic Optimization." - * arXiv preprint arXiv:1412.6980v8 (2014). - */ -template -class AdamSolver : public SGDSolver { - public: - explicit AdamSolver(const SolverParameter& param) - : SGDSolver(param) { AdamPreSolve();} - explicit AdamSolver(const string& param_file) - : SGDSolver(param_file) { AdamPreSolve(); } - - protected: - void AdamPreSolve(); - virtual void ComputeUpdateValue(int param_id, Dtype rate); - - DISABLE_COPY_AND_ASSIGN(AdamSolver); -}; - -template -Solver* GetSolver(const SolverParameter& param) { - SolverParameter_SolverType type = param.solver_type(); - - switch (type) { - case SolverParameter_SolverType_SGD: - return new SGDSolver(param); - case SolverParameter_SolverType_NESTEROV: - return new NesterovSolver(param); - case SolverParameter_SolverType_ADAGRAD: - return new AdaGradSolver(param); - case SolverParameter_SolverType_RMSPROP: - return new RMSPropSolver(param); - case SolverParameter_SolverType_ADADELTA: - return new AdaDeltaSolver(param); - case SolverParameter_SolverType_ADAM: - return new AdamSolver(param); - default: - LOG(FATAL) << "Unknown SolverType: " << type; - } - return (Solver*) NULL; -} - } // namespace caffe -#endif // CAFFE_OPTIMIZATION_SOLVER_HPP_ +#endif // CAFFE_SOLVER_HPP_ diff --git a/include/caffe/solver_factory.hpp b/include/caffe/solver_factory.hpp new file mode 100644 index 00000000000..cfff721af40 --- /dev/null +++ b/include/caffe/solver_factory.hpp @@ -0,0 +1,137 @@ +/** + * @brief A solver factory that allows one to register 
solvers, similar to + * layer factory. During runtime, registered solvers could be called by passing + * a SolverParameter protobuffer to the CreateSolver function: + * + * SolverRegistry::CreateSolver(param); + * + * There are two ways to register a solver. Assuming that we have a solver like: + * + * template + * class MyAwesomeSolver : public Solver { + * // your implementations + * }; + * + * and its type is its C++ class name, but without the "Solver" at the end + * ("MyAwesomeSolver" -> "MyAwesome"). + * + * If the solver is going to be created simply by its constructor, in your c++ + * file, add the following line: + * + * REGISTER_SOLVER_CLASS(MyAwesome); + * + * Or, if the solver is going to be created by another creator function, in the + * format of: + * + * template + * Solver GetMyAwesomeSolver(const SolverParameter& param) { + * // your implementation + * } + * + * then you can register the creator function instead, like + * + * REGISTER_SOLVER_CREATOR(MyAwesome, GetMyAwesomeSolver) + * + * Note that each solver type should only be registered once. + */ + +#ifndef CAFFE_SOLVER_FACTORY_H_ +#define CAFFE_SOLVER_FACTORY_H_ + +#include +#include +#include + +#include "caffe/common.hpp" +#include "caffe/proto/caffe.pb.h" + +namespace caffe { + +template +class Solver; + +template +class SolverRegistry { + public: + typedef Solver* (*Creator)(const SolverParameter&); + typedef std::map CreatorRegistry; + + static CreatorRegistry& Registry() { + static CreatorRegistry* g_registry_ = new CreatorRegistry(); + return *g_registry_; + } + + // Adds a creator. + static void AddCreator(const string& type, Creator creator) { + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 0) + << "Solver type " << type << " already registered."; + registry[type] = creator; + } + + // Get a solver using a SolverParameter. 
+ static Solver* CreateSolver(const SolverParameter& param) { + const string& type = param.type(); + CreatorRegistry& registry = Registry(); + CHECK_EQ(registry.count(type), 1) << "Unknown solver type: " << type + << " (known types: " << SolverTypeListString() << ")"; + return registry[type](param); + } + + static vector SolverTypeList() { + CreatorRegistry& registry = Registry(); + vector solver_types; + for (typename CreatorRegistry::iterator iter = registry.begin(); + iter != registry.end(); ++iter) { + solver_types.push_back(iter->first); + } + return solver_types; + } + + private: + // Solver registry should never be instantiated - everything is done with its + // static variables. + SolverRegistry() {} + + static string SolverTypeListString() { + vector solver_types = SolverTypeList(); + string solver_types_str; + for (vector::iterator iter = solver_types.begin(); + iter != solver_types.end(); ++iter) { + if (iter != solver_types.begin()) { + solver_types_str += ", "; + } + solver_types_str += *iter; + } + return solver_types_str; + } +}; + + +template +class SolverRegisterer { + public: + SolverRegisterer(const string& type, + Solver* (*creator)(const SolverParameter&)) { + // LOG(INFO) << "Registering solver type: " << type; + SolverRegistry::AddCreator(type, creator); + } +}; + + +#define REGISTER_SOLVER_CREATOR(type, creator) \ + static SolverRegisterer g_creator_f_##type(#type, creator); \ + static SolverRegisterer g_creator_d_##type(#type, creator) \ + +#define REGISTER_SOLVER_CLASS(type) \ + template \ + Solver* Creator_##type##Solver( \ + const SolverParameter& param) \ + { \ + return new type##Solver(param); \ + } \ + REGISTER_SOLVER_CREATOR(type, Creator_##type##Solver) + +} // namespace caffe + +#endif // CAFFE_SOLVER_FACTORY_H_ diff --git a/include/caffe/util/upgrade_proto.hpp b/include/caffe/util/upgrade_proto.hpp index 6a1418434a6..c94bb3caaa3 100644 --- a/include/caffe/util/upgrade_proto.hpp +++ b/include/caffe/util/upgrade_proto.hpp @@ -59,6 
+59,18 @@ bool UpgradeV1LayerParameter(const V1LayerParameter& v1_layer_param, const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type); +// Return true iff the solver contains any old solver_type specified as enums +bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param); + +bool UpgradeSolverType(SolverParameter* solver_param); + +// Check for deprecations and upgrade the SolverParameter as needed. +bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param); + +// Read parameters from a file into a SolverParameter proto message. +void ReadSolverParamsFromTextFileOrDie(const string& param_file, + SolverParameter* param); + } // namespace caffe #endif // CAFFE_UTIL_UPGRADE_PROTO_H_ diff --git a/matlab/+caffe/private/caffe_.cpp b/matlab/+caffe/private/caffe_.cpp index 7883f79ebd9..1641e14b534 100644 --- a/matlab/+caffe/private/caffe_.cpp +++ b/matlab/+caffe/private/caffe_.cpp @@ -188,7 +188,10 @@ static void get_solver(MEX_ARGS) { "Usage: caffe_('get_solver', solver_file)"); char* solver_file = mxArrayToString(prhs[0]); mxCHECK_FILE_EXIST(solver_file); - shared_ptr > solver(new caffe::SGDSolver(solver_file)); + SolverParameter solver_param; + ReadSolverParamsFromTextFileOrDie(solver_file, &solver_param); + shared_ptr > solver( + SolverRegistry::CreateSolver(solver_param)); solvers_.push_back(solver); plhs[0] = ptr_to_handle >(solver.get()); mxFree(solver_file); diff --git a/python/caffe/_caffe.cpp b/python/caffe/_caffe.cpp index ccd5776ac40..8687dd872eb 100644 --- a/python/caffe/_caffe.cpp +++ b/python/caffe/_caffe.cpp @@ -16,6 +16,7 @@ #include "caffe/caffe.hpp" #include "caffe/python_layer.hpp" +#include "caffe/sgd_solvers.hpp" // Temporary solution for numpy < 1.7 versions: old macro, no promises. // You're strongly advised to upgrade to >= 1.7. 
@@ -133,8 +134,8 @@ void Net_SetInputArrays(Net* net, bp::object data_obj, Solver* GetSolverFromFile(const string& filename) { SolverParameter param; - ReadProtoFromTextFileOrDie(filename, &param); - return GetSolver(param); + ReadSolverParamsFromTextFileOrDie(filename, &param); + return SolverRegistry::CreateSolver(param); } struct NdarrayConverterGenerator { diff --git a/src/caffe/proto/caffe.proto b/src/caffe/proto/caffe.proto index 4794991f917..76c869c127e 100644 --- a/src/caffe/proto/caffe.proto +++ b/src/caffe/proto/caffe.proto @@ -98,7 +98,7 @@ message NetParameter { // NOTE // Update the next available ID when you add a new SolverParameter field. // -// SolverParameter next available ID: 40 (last added: momentum2) +// SolverParameter next available ID: 41 (last added: type) message SolverParameter { ////////////////////////////////////////////////////////////////////////////// // Specifying the train and test networks @@ -209,16 +209,9 @@ message SolverParameter { // (and by default) initialize using a seed derived from the system clock. optional int64 random_seed = 20 [default = -1]; - // Solver type - enum SolverType { - SGD = 0; - NESTEROV = 1; - ADAGRAD = 2; - RMSPROP = 3; - ADADELTA = 4; - ADAM = 5; - } - optional SolverType solver_type = 30 [default = SGD]; + // type of the solver + optional string type = 40 [default = "SGD"]; + // numerical stability for RMSProp, AdaGrad and AdaDelta and Adam optional float delta = 31 [default = 1e-8]; // parameters for the Adam solver @@ -234,6 +227,18 @@ message SolverParameter { // If false, don't save a snapshot after training finishes. 
optional bool snapshot_after_train = 28 [default = true]; + + // DEPRECATED: old solver enum types, use string instead + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + RMSPROP = 3; + ADADELTA = 4; + ADAM = 5; + } + // DEPRECATED: use type instead of solver_type + optional SolverType solver_type = 30 [default = SGD]; } // A message that stores the solver snapshots diff --git a/src/caffe/solver.cpp b/src/caffe/solver.cpp index 12c13dd8385..d3bc7361dd5 100644 --- a/src/caffe/solver.cpp +++ b/src/caffe/solver.cpp @@ -1,18 +1,11 @@ #include -#include #include #include -#include "hdf5.h" -#include "hdf5_hl.h" - -#include "caffe/net.hpp" -#include "caffe/proto/caffe.pb.h" #include "caffe/solver.hpp" #include "caffe/util/hdf5.hpp" #include "caffe/util/io.hpp" -#include "caffe/util/math_functions.hpp" #include "caffe/util/upgrade_proto.hpp" namespace caffe { @@ -43,7 +36,7 @@ Solver::Solver(const string& param_file, const Solver* root_solver) : net_(), callbacks_(), root_solver_(root_solver), requested_early_exit_(false) { SolverParameter param; - ReadProtoFromTextFileOrDie(param_file, &param); + ReadSolverParamsFromTextFileOrDie(param_file, &param); Init(param); } @@ -492,810 +485,6 @@ void Solver::Restore(const char* state_file) { } } -// Return the current learning rate. The currently implemented learning rate -// policies are as follows: -// - fixed: always return base_lr. -// - step: return base_lr * gamma ^ (floor(iter / step)) -// - exp: return base_lr * gamma ^ iter -// - inv: return base_lr * (1 + gamma * iter) ^ (- power) -// - multistep: similar to step but it allows non uniform steps defined by -// stepvalue -// - poly: the effective learning rate follows a polynomial decay, to be -// zero by the max_iter. 
return base_lr (1 - iter/max_iter) ^ (power) -// - sigmoid: the effective learning rate follows a sigmod decay -// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) -// -// where base_lr, max_iter, gamma, step, stepvalue and power are defined -// in the solver parameter protocol buffer, and iter is the current iteration. -template -Dtype SGDSolver::GetLearningRate() { - Dtype rate; - const string& lr_policy = this->param_.lr_policy(); - if (lr_policy == "fixed") { - rate = this->param_.base_lr(); - } else if (lr_policy == "step") { - this->current_step_ = this->iter_ / this->param_.stepsize(); - rate = this->param_.base_lr() * - pow(this->param_.gamma(), this->current_step_); - } else if (lr_policy == "exp") { - rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); - } else if (lr_policy == "inv") { - rate = this->param_.base_lr() * - pow(Dtype(1) + this->param_.gamma() * this->iter_, - - this->param_.power()); - } else if (lr_policy == "multistep") { - if (this->current_step_ < this->param_.stepvalue_size() && - this->iter_ >= this->param_.stepvalue(this->current_step_)) { - this->current_step_++; - LOG(INFO) << "MultiStep Status: Iteration " << - this->iter_ << ", step = " << this->current_step_; - } - rate = this->param_.base_lr() * - pow(this->param_.gamma(), this->current_step_); - } else if (lr_policy == "poly") { - rate = this->param_.base_lr() * pow(Dtype(1.) - - (Dtype(this->iter_) / Dtype(this->param_.max_iter())), - this->param_.power()); - } else if (lr_policy == "sigmoid") { - rate = this->param_.base_lr() * (Dtype(1.) / - (Dtype(1.) 
+ exp(-this->param_.gamma() * (Dtype(this->iter_) - - Dtype(this->param_.stepsize()))))); - } else { - LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; - } - return rate; -} - -template -void SGDSolver::PreSolve() { - // Initialize the history - const vector*>& net_params = this->net_->learnable_params(); - history_.clear(); - update_.clear(); - temp_.clear(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); - history_.push_back(shared_ptr >(new Blob(shape))); - update_.push_back(shared_ptr >(new Blob(shape))); - temp_.push_back(shared_ptr >(new Blob(shape))); - } -} - -template -void SGDSolver::ClipGradients() { - const Dtype clip_gradients = this->param_.clip_gradients(); - if (clip_gradients < 0) { return; } - const vector*>& net_params = this->net_->learnable_params(); - Dtype sumsq_diff = 0; - for (int i = 0; i < net_params.size(); ++i) { - sumsq_diff += net_params[i]->sumsq_diff(); - } - const Dtype l2norm_diff = std::sqrt(sumsq_diff); - if (l2norm_diff > clip_gradients) { - Dtype scale_factor = clip_gradients / l2norm_diff; - LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm " - << l2norm_diff << " > " << clip_gradients << ") " - << "by scale factor " << scale_factor; - for (int i = 0; i < net_params.size(); ++i) { - net_params[i]->scale_diff(scale_factor); - } - } -} - -template -void SGDSolver::ApplyUpdate() { - CHECK(Caffe::root_solver()); - Dtype rate = GetLearningRate(); - if (this->param_.display() && this->iter_ % this->param_.display() == 0) { - LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; - } - ClipGradients(); - for (int param_id = 0; param_id < this->net_->learnable_params().size(); - ++param_id) { - Normalize(param_id); - Regularize(param_id); - ComputeUpdateValue(param_id, rate); - } - this->net_->Update(); -} - -template -void SGDSolver::Normalize(int param_id) { - if (this->param_.iter_size() == 1) { return; } - // Scale gradient to counterbalance 
accumulation. - const vector*>& net_params = this->net_->learnable_params(); - const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); - switch (Caffe::mode()) { - case Caffe::CPU: { - caffe_scal(net_params[param_id]->count(), accum_normalization, - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - caffe_gpu_scal(net_params[param_id]->count(), accum_normalization, - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::Regularize(int param_id) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_weight_decay = - this->net_->params_weight_decay(); - Dtype weight_decay = this->param_.weight_decay(); - string regularization_type = this->param_.regularization_type(); - Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else if (regularization_type == "L1") { - caffe_cpu_sign(net_params[param_id]->count(), - net_params[param_id]->cpu_data(), - temp_[param_id]->mutable_cpu_data()); - caffe_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - if (local_decay) { - if (regularization_type == "L2") { - // add weight decay - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else if (regularization_type == "L1") { - 
caffe_gpu_sign(net_params[param_id]->count(), - net_params[param_id]->gpu_data(), - temp_[param_id]->mutable_gpu_data()); - caffe_gpu_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - } else { - LOG(FATAL) << "Unknown regularization type: " << regularization_type; - } - } -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - // Compute the update to history, then copy it to the parameter diff. - switch (Caffe::mode()) { - case Caffe::CPU: { - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - history_[param_id]->mutable_cpu_data()); - caffe_copy(net_params[param_id]->count(), - history_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - history_[param_id]->mutable_gpu_data()); - caffe_copy(net_params[param_id]->count(), - history_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void SGDSolver::SnapshotSolverState(const string& model_filename) { - switch (this->param_.snapshot_format()) { - case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: - SnapshotSolverStateToBinaryProto(model_filename); - break; - case caffe::SolverParameter_SnapshotFormat_HDF5: - SnapshotSolverStateToHDF5(model_filename); - break; - default: - LOG(FATAL) << "Unsupported snapshot format."; - 
} -} - -template -void SGDSolver::SnapshotSolverStateToBinaryProto( - const string& model_filename) { - SolverState state; - state.set_iter(this->iter_); - state.set_learned_net(model_filename); - state.set_current_step(this->current_step_); - state.clear_history(); - for (int i = 0; i < history_.size(); ++i) { - // Add history - BlobProto* history_blob = state.add_history(); - history_[i]->ToProto(history_blob); - } - string snapshot_filename = Solver::SnapshotFilename(".solverstate"); - LOG(INFO) - << "Snapshotting solver state to binary proto file " << snapshot_filename; - WriteProtoToBinaryFile(state, snapshot_filename.c_str()); -} - -template -void SGDSolver::SnapshotSolverStateToHDF5( - const string& model_filename) { - string snapshot_filename = - Solver::SnapshotFilename(".solverstate.h5"); - LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename; - hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC, - H5P_DEFAULT, H5P_DEFAULT); - CHECK_GE(file_hid, 0) - << "Couldn't open " << snapshot_filename << " to save solver state."; - hdf5_save_int(file_hid, "iter", this->iter_); - hdf5_save_string(file_hid, "learned_net", model_filename); - hdf5_save_int(file_hid, "current_step", this->current_step_); - hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT, - H5P_DEFAULT); - CHECK_GE(history_hid, 0) - << "Error saving solver state to " << snapshot_filename << "."; - for (int i = 0; i < history_.size(); ++i) { - ostringstream oss; - oss << i; - hdf5_save_nd_dataset(history_hid, oss.str(), *history_[i]); - } - H5Gclose(history_hid); - H5Fclose(file_hid); -} - -template -void SGDSolver::RestoreSolverStateFromBinaryProto( - const string& state_file) { - SolverState state; - ReadProtoFromBinaryFile(state_file, &state); - this->iter_ = state.iter(); - if (state.has_learned_net()) { - NetParameter net_param; - ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param); - 
this->net_->CopyTrainedLayersFrom(net_param); - } - this->current_step_ = state.current_step(); - CHECK_EQ(state.history_size(), history_.size()) - << "Incorrect length of history blobs."; - LOG(INFO) << "SGDSolver: restoring history"; - for (int i = 0; i < history_.size(); ++i) { - history_[i]->FromProto(state.history(i)); - } -} - -template -void SGDSolver::RestoreSolverStateFromHDF5(const string& state_file) { - hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); - CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file; - this->iter_ = hdf5_load_int(file_hid, "iter"); - if (H5LTfind_dataset(file_hid, "learned_net")) { - string learned_net = hdf5_load_string(file_hid, "learned_net"); - this->net_->CopyTrainedLayersFrom(learned_net); - } - this->current_step_ = hdf5_load_int(file_hid, "current_step"); - hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT); - CHECK_GE(history_hid, 0) << "Error reading history from " << state_file; - int state_history_size = hdf5_get_num_links(history_hid); - CHECK_EQ(state_history_size, history_.size()) - << "Incorrect length of history blobs."; - for (int i = 0; i < history_.size(); ++i) { - ostringstream oss; - oss << i; - hdf5_load_nd_dataset(history_hid, oss.str().c_str(), 0, - kMaxBlobAxes, history_[i].get()); - } - H5Gclose(history_hid); - H5Fclose(file_hid); -} - -template -void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_cpu_axpby(net_params[param_id]->count(), 
local_rate, - net_params[param_id]->cpu_diff(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // compute update: step back then over step - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->cpu_data(), -momentum, - this->update_[param_id]->mutable_cpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // save history momentum for stepping back - caffe_copy(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // compute update: step back then over step - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, - this->history_[param_id]->gpu_data(), -momentum, - this->update_[param_id]->mutable_gpu_data()); - - // copy - caffe_copy(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { - CHECK(Caffe::root_solver()); - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype delta = this->param_.delta(); - Dtype local_rate = rate * net_params_lr[param_id]; - switch (Caffe::mode()) { - case Caffe::CPU: { - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_add(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - 
this->history_[param_id]->cpu_data(), - this->history_[param_id]->mutable_cpu_data()); - - // prepare update - caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_cpu_data()); - - caffe_div(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // scale and copy - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->cpu_data(), Dtype(0), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_add(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->history_[param_id]->mutable_gpu_data()); - - // prepare update - caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_div(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // scale and copy - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->gpu_data(), Dtype(0), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = 
this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - - // get the learning rate - Dtype delta = this->param_.delta(); - Dtype rms_decay = this->param_.rms_decay(); - Dtype local_rate = rate * net_params_lr[param_id]; - - switch (Caffe::mode()) { - case Caffe::CPU: - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history - caffe_cpu_axpby(net_params[param_id] -> count(), - Dtype(1-rms_decay), this->update_[param_id]->cpu_data(), - rms_decay, this->history_[param_id]-> mutable_cpu_data()); - - // prepare update - caffe_powx(net_params[param_id]->count(), - this->history_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - caffe_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_cpu_data()); - - caffe_div(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // scale and copy - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->cpu_data(), Dtype(0), - net_params[param_id]->mutable_cpu_diff()); - break; - case Caffe::GPU: -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history - caffe_gpu_axpby(net_params[param_id] -> count(), - Dtype(1-rms_decay), this->update_[param_id]->gpu_data(), - rms_decay, this->history_[param_id]-> mutable_gpu_data()); - - // prepare update - caffe_gpu_powx(net_params[param_id]->count(), - this->history_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add_scalar(net_params[param_id]->count(), - delta, this->update_[param_id]->mutable_gpu_data()); - - 
caffe_gpu_div(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_axpby(net_params[param_id]->count(), local_rate, - this->update_[param_id]->gpu_data(), Dtype(0), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdaDeltaSolver::AdaDeltaPreSolve() { - // Add the extra history entries for AdaDelta after those from - // SGDSolver::PreSolve - const vector*>& net_params = this->net_->learnable_params(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); - this->history_.push_back( - shared_ptr >(new Blob(shape))); - } -} - -template -void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype delta = this->param_.delta(); - Dtype momentum = this->param_.momentum(); - Dtype local_rate = rate * net_params_lr[param_id]; - size_t update_history_offset = net_params.size(); - switch (Caffe::mode()) { - case Caffe::CPU: { - // compute square of gradient in update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of gradients - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->cpu_data(), momentum, - this->history_[param_id]->mutable_cpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_cpu_data()); - - caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[update_history_offset + param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - 
caffe_add(net_params[param_id]->count(), - this->temp_[param_id]->cpu_data(), - this->history_[param_id]->cpu_data(), - this->temp_[param_id]->mutable_cpu_data()); - - // divide history of updates by history of gradients - caffe_div(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), - this->temp_[param_id]->cpu_data(), - this->update_[param_id]->mutable_cpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_powx(net_params[param_id]->count(), - this->update_[param_id]->cpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_cpu_data()); - - // compute the update - caffe_mul(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), - this->update_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - - // compute square of update - caffe_powx(net_params[param_id]->count(), - net_params[param_id]->cpu_diff(), Dtype(2), - this->update_[param_id]->mutable_cpu_data()); - - // update history of updates - caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->cpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_cpu_data()); - - // apply learning rate - caffe_cpu_scale(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // compute square of gradient in update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of gradients - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[param_id]->mutable_gpu_data()); - - // add delta to history to guard against dividing by zero later - caffe_gpu_set(net_params[param_id]->count(), delta, - this->temp_[param_id]->mutable_gpu_data()); - - 
caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[update_history_offset + param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - caffe_gpu_add(net_params[param_id]->count(), - this->temp_[param_id]->gpu_data(), - this->history_[param_id]->gpu_data(), - this->temp_[param_id]->mutable_gpu_data()); - - // divide history of updates by history of gradients - caffe_gpu_div(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), - this->temp_[param_id]->gpu_data(), - this->update_[param_id]->mutable_gpu_data()); - - // jointly compute the RMS of both for update and gradient history - caffe_gpu_powx(net_params[param_id]->count(), - this->update_[param_id]->gpu_data(), Dtype(0.5), - this->update_[param_id]->mutable_gpu_data()); - - // compute the update and copy to net_diff - caffe_gpu_mul(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), - this->update_[param_id]->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); - - // compute square of update - caffe_gpu_powx(net_params[param_id]->count(), - net_params[param_id]->gpu_diff(), Dtype(2), - this->update_[param_id]->mutable_gpu_data()); - - // update history of updates - caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, - this->update_[param_id]->gpu_data(), momentum, - this->history_[update_history_offset + param_id]->mutable_gpu_data()); - - // apply learning rate - caffe_gpu_scale(net_params[param_id]->count(), local_rate, - net_params[param_id]->gpu_diff(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - -template -void AdamSolver::AdamPreSolve() { - // Add the extra history entries for Adam after those from - // SGDSolver::PreSolve - const vector*>& net_params = this->net_->learnable_params(); - for (int i = 0; i < net_params.size(); ++i) { - const vector& shape = net_params[i]->shape(); 
- this->history_.push_back( - shared_ptr >(new Blob(shape))); - } -} - -template -void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { - const vector*>& net_params = this->net_->learnable_params(); - const vector& net_params_lr = this->net_->params_lr(); - Dtype local_rate = rate * net_params_lr[param_id]; - const Dtype beta1 = this->param_.momentum(); - const Dtype beta2 = this->param_.momentum2(); - - // we create aliases for convenience - size_t update_history_offset = net_params.size(); - Blob* val_m = this->history_[param_id].get(); - Blob* val_v = this->history_[param_id + update_history_offset].get(); - Blob* val_t = this->temp_[param_id].get(); - - const int t = this->iter_ + 1; - const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / - (Dtype(1.) - pow(beta1, t)); - const int N = net_params[param_id]->count(); - const Dtype eps_hat = this->param_.delta(); - - switch (Caffe::mode()) { - case Caffe::CPU: { - // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t - caffe_cpu_axpby(N, Dtype(1)-beta1, - net_params[param_id]->cpu_diff(), beta1, - val_m->mutable_cpu_data()); - - // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 - caffe_mul(N, - net_params[param_id]->cpu_diff(), - net_params[param_id]->cpu_diff(), - val_t->mutable_cpu_data()); - caffe_cpu_axpby(N, Dtype(1)-beta2, - val_t->cpu_data(), beta2, - val_v->mutable_cpu_data()); - - // set update - caffe_powx(N, - val_v->cpu_data(), Dtype(0.5), - val_t->mutable_cpu_data()); - caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data()); - caffe_div(N, - val_m->cpu_data(), - val_t->cpu_data(), - val_t->mutable_cpu_data()); - - caffe_cpu_scale(N, local_rate*correction, - val_t->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - break; - } - case Caffe::GPU: { -#ifndef CPU_ONLY - // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t - caffe_gpu_axpby(N, Dtype(1)-beta1, - net_params[param_id]->gpu_diff(), beta1, - val_m->mutable_gpu_data()); - - // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 - 
caffe_gpu_mul(N, - net_params[param_id]->gpu_diff(), - net_params[param_id]->gpu_diff(), - val_t->mutable_gpu_data()); - caffe_gpu_axpby(N, Dtype(1)-beta2, - val_t->gpu_data(), beta2, - val_v->mutable_gpu_data()); - - // set update - caffe_gpu_powx(N, - val_v->gpu_data(), Dtype(0.5), - val_t->mutable_gpu_data()); - caffe_gpu_add_scalar(N, eps_hat, - val_t->mutable_gpu_data()); - caffe_gpu_div(N, - val_m->gpu_data(), - val_t->gpu_data(), - val_t->mutable_gpu_data()); - - caffe_gpu_scale(N, local_rate*correction, - val_t->gpu_data(), - net_params[param_id]->mutable_gpu_diff()); -#else - NO_GPU; -#endif - break; - } - default: - LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); - } -} - INSTANTIATE_CLASS(Solver); -INSTANTIATE_CLASS(SGDSolver); -INSTANTIATE_CLASS(NesterovSolver); -INSTANTIATE_CLASS(AdaGradSolver); -INSTANTIATE_CLASS(RMSPropSolver); -INSTANTIATE_CLASS(AdaDeltaSolver); -INSTANTIATE_CLASS(AdamSolver); } // namespace caffe diff --git a/src/caffe/solvers/adadelta_solver.cpp b/src/caffe/solvers/adadelta_solver.cpp new file mode 100644 index 00000000000..a37899ebbb4 --- /dev/null +++ b/src/caffe/solvers/adadelta_solver.cpp @@ -0,0 +1,156 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void AdaDeltaSolver::AdaDeltaPreSolve() { + // Add the extra history entries for AdaDelta after those from + // SGDSolver::PreSolve + const vector*>& net_params = this->net_->learnable_params(); + for (int i = 0; i < net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); + } +} + +template +void AdaDeltaSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype delta = this->param_.delta(); + Dtype momentum = this->param_.momentum(); + Dtype local_rate = rate * net_params_lr[param_id]; + size_t update_history_offset = net_params.size(); + 
switch (Caffe::mode()) { + case Caffe::CPU: { + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history of gradients + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[update_history_offset + param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add(net_params[param_id]->count(), + this->temp_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->temp_[param_id]->mutable_cpu_data()); + + // divide history of updates by history of gradients + caffe_div(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->temp_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_powx(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + // compute the update + caffe_mul(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + + // compute square of update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history of updates + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->cpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_cpu_data()); + + // apply learning 
rate + caffe_cpu_scale(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of gradients + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // add delta to history to guard against dividing by zero later + caffe_gpu_set(net_params[param_id]->count(), delta, + this->temp_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[update_history_offset + param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add(net_params[param_id]->count(), + this->temp_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->temp_[param_id]->mutable_gpu_data()); + + // divide history of updates by history of gradients + caffe_gpu_div(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->temp_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // jointly compute the RMS of both for update and gradient history + caffe_gpu_powx(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + // compute the update and copy to net_diff + caffe_gpu_mul(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + + // compute square of update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history of updates + 
caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) - momentum, + this->update_[param_id]->gpu_data(), momentum, + this->history_[update_history_offset + param_id]->mutable_gpu_data()); + + // apply learning rate + caffe_gpu_scale(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdaDeltaSolver); +REGISTER_SOLVER_CLASS(AdaDelta); + +} // namespace caffe diff --git a/src/caffe/solvers/adagrad_solver.cpp b/src/caffe/solvers/adagrad_solver.cpp new file mode 100644 index 00000000000..5e406326095 --- /dev/null +++ b/src/caffe/solvers/adagrad_solver.cpp @@ -0,0 +1,89 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void AdaGradSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype delta = this->param_.delta(); + Dtype local_rate = rate * net_params_lr[param_id]; + switch (Caffe::mode()) { + case Caffe::CPU: { + // compute square of gradient in update + caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_add(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + this->history_[param_id]->cpu_data(), + this->history_[param_id]->mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), + this->update_[param_id]->cpu_data(), + 
this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_add(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + this->history_[param_id]->gpu_data(), + this->history_[param_id]->mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), + this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // scale and copy + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdaGradSolver); +REGISTER_SOLVER_CLASS(AdaGrad); + +} // namespace caffe diff --git a/src/caffe/solvers/adam_solver.cpp b/src/caffe/solvers/adam_solver.cpp new file mode 100644 index 00000000000..cb0fbfe2f78 --- /dev/null +++ b/src/caffe/solvers/adam_solver.cpp @@ -0,0 +1,113 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void AdamSolver::AdamPreSolve() { + // Add the extra history entries for Adam after those from + // SGDSolver::PreSolve + const vector*>& net_params = this->net_->learnable_params(); + for (int i = 0; i < 
net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + this->history_.push_back( + shared_ptr >(new Blob(shape))); + } +} + +template +void AdamSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype local_rate = rate * net_params_lr[param_id]; + const Dtype beta1 = this->param_.momentum(); + const Dtype beta2 = this->param_.momentum2(); + + // we create aliases for convenience + size_t update_history_offset = net_params.size(); + Blob* val_m = this->history_[param_id].get(); + Blob* val_v = this->history_[param_id + update_history_offset].get(); + Blob* val_t = this->temp_[param_id].get(); + + const int t = this->iter_ + 1; + const Dtype correction = std::sqrt(Dtype(1) - pow(beta2, t)) / + (Dtype(1.) - pow(beta1, t)); + const int N = net_params[param_id]->count(); + const Dtype eps_hat = this->param_.delta(); + + switch (Caffe::mode()) { + case Caffe::CPU: { + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_cpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->cpu_diff(), beta1, + val_m->mutable_cpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_mul(N, + net_params[param_id]->cpu_diff(), + net_params[param_id]->cpu_diff(), + val_t->mutable_cpu_data()); + caffe_cpu_axpby(N, Dtype(1)-beta2, + val_t->cpu_data(), beta2, + val_v->mutable_cpu_data()); + + // set update + caffe_powx(N, + val_v->cpu_data(), Dtype(0.5), + val_t->mutable_cpu_data()); + caffe_add_scalar(N, eps_hat, val_t->mutable_cpu_data()); + caffe_div(N, + val_m->cpu_data(), + val_t->cpu_data(), + val_t->mutable_cpu_data()); + + caffe_cpu_scale(N, local_rate*correction, + val_t->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // update m <- \beta_1 m_{t-1} + (1-\beta_1)g_t + caffe_gpu_axpby(N, Dtype(1)-beta1, + net_params[param_id]->gpu_diff(), beta1, + 
val_m->mutable_gpu_data()); + + // update v <- \beta_2 m_{t-1} + (1-\beta_2)g_t^2 + caffe_gpu_mul(N, + net_params[param_id]->gpu_diff(), + net_params[param_id]->gpu_diff(), + val_t->mutable_gpu_data()); + caffe_gpu_axpby(N, Dtype(1)-beta2, + val_t->gpu_data(), beta2, + val_v->mutable_gpu_data()); + + // set update + caffe_gpu_powx(N, + val_v->gpu_data(), Dtype(0.5), + val_t->mutable_gpu_data()); + caffe_gpu_add_scalar(N, eps_hat, + val_t->mutable_gpu_data()); + caffe_gpu_div(N, + val_m->gpu_data(), + val_t->gpu_data(), + val_t->mutable_gpu_data()); + + caffe_gpu_scale(N, local_rate*correction, + val_t->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(AdamSolver); +REGISTER_SOLVER_CLASS(Adam); + +} // namespace caffe diff --git a/src/caffe/solvers/nesterov_solver.cpp b/src/caffe/solvers/nesterov_solver.cpp new file mode 100644 index 00000000000..34bf01ebf29 --- /dev/null +++ b/src/caffe/solvers/nesterov_solver.cpp @@ -0,0 +1,71 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void NesterovSolver::ComputeUpdateValue(int param_id, Dtype rate) { + CHECK(Caffe::root_solver()); + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype momentum = this->param_.momentum(); + Dtype local_rate = rate * net_params_lr[param_id]; + switch (Caffe::mode()) { + case Caffe::CPU: { + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + this->history_[param_id]->mutable_cpu_data()); + + // compute update: step back then over step + caffe_cpu_axpby(net_params[param_id]->count(), Dtype(1) + 
momentum, + this->history_[param_id]->cpu_data(), -momentum, + this->update_[param_id]->mutable_cpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + // save history momentum for stepping back + caffe_copy(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + this->history_[param_id]->mutable_gpu_data()); + + // compute update: step back then over step + caffe_gpu_axpby(net_params[param_id]->count(), Dtype(1) + momentum, + this->history_[param_id]->gpu_data(), -momentum, + this->update_[param_id]->mutable_gpu_data()); + + // copy + caffe_copy(net_params[param_id]->count(), + this->update_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(NesterovSolver); +REGISTER_SOLVER_CLASS(Nesterov); + +} // namespace caffe diff --git a/src/caffe/solvers/rmsprop_solver.cpp b/src/caffe/solvers/rmsprop_solver.cpp new file mode 100644 index 00000000000..c6247676094 --- /dev/null +++ b/src/caffe/solvers/rmsprop_solver.cpp @@ -0,0 +1,85 @@ +#include + +#include "caffe/sgd_solvers.hpp" + +namespace caffe { + +template +void RMSPropSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + + // get the learning rate + Dtype delta = this->param_.delta(); + Dtype rms_decay = this->param_.rms_decay(); + Dtype local_rate = rate * net_params_lr[param_id]; + + switch (Caffe::mode()) { + case Caffe::CPU: + // compute square of gradient in update + 
caffe_powx(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), Dtype(2), + this->update_[param_id]->mutable_cpu_data()); + + // update history + caffe_cpu_axpby(net_params[param_id] -> count(), + Dtype(1-rms_decay), this->update_[param_id]->cpu_data(), + rms_decay, this->history_[param_id]-> mutable_cpu_data()); + + // prepare update + caffe_powx(net_params[param_id]->count(), + this->history_[param_id]->cpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_cpu_data()); + + caffe_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_cpu_data()); + + caffe_div(net_params[param_id]->count(), + net_params[param_id]->cpu_diff(), this->update_[param_id]->cpu_data(), + this->update_[param_id]->mutable_cpu_data()); + + // scale and copy + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->cpu_data(), Dtype(0), + net_params[param_id]->mutable_cpu_diff()); + break; + case Caffe::GPU: +#ifndef CPU_ONLY + // compute square of gradient in update + caffe_gpu_powx(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), Dtype(2), + this->update_[param_id]->mutable_gpu_data()); + + // update history + caffe_gpu_axpby(net_params[param_id] -> count(), + Dtype(1-rms_decay), this->update_[param_id]->gpu_data(), + rms_decay, this->history_[param_id]-> mutable_gpu_data()); + + // prepare update + caffe_gpu_powx(net_params[param_id]->count(), + this->history_[param_id]->gpu_data(), Dtype(0.5), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_add_scalar(net_params[param_id]->count(), + delta, this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_div(net_params[param_id]->count(), + net_params[param_id]->gpu_diff(), this->update_[param_id]->gpu_data(), + this->update_[param_id]->mutable_gpu_data()); + + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + this->update_[param_id]->gpu_data(), Dtype(0), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; 
+#endif + break; + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +INSTANTIATE_CLASS(RMSPropSolver); +REGISTER_SOLVER_CLASS(RMSProp); + +} // namespace caffe diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp new file mode 100644 index 00000000000..32bf19b17c8 --- /dev/null +++ b/src/caffe/solvers/sgd_solver.cpp @@ -0,0 +1,348 @@ +#include +#include + +#include "caffe/sgd_solvers.hpp" +#include "caffe/util/hdf5.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +namespace caffe { + +// Return the current learning rate. The currently implemented learning rate +// policies are as follows: +// - fixed: always return base_lr. +// - step: return base_lr * gamma ^ (floor(iter / stepsize)) +// - exp: return base_lr * gamma ^ iter +// - inv: return base_lr * (1 + gamma * iter) ^ (- power) +// - multistep: similar to step but it allows non uniform steps defined by +// stepvalue +// - poly: the effective learning rate follows a polynomial decay, to be +// zero by the max_iter. return base_lr (1 - iter/max_iter) ^ (power) +// - sigmoid: the effective learning rate follows a sigmoid decay +// return base_lr ( 1/(1 + exp(-gamma * (iter - stepsize)))) +// +// where base_lr, max_iter, gamma, stepsize, stepvalue and power are defined +// in the solver parameter protocol buffer, and iter is the current iteration. 
+template +Dtype SGDSolver::GetLearningRate() { + Dtype rate; + const string& lr_policy = this->param_.lr_policy(); + if (lr_policy == "fixed") { + rate = this->param_.base_lr(); + } else if (lr_policy == "step") { + this->current_step_ = this->iter_ / this->param_.stepsize(); + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "exp") { + rate = this->param_.base_lr() * pow(this->param_.gamma(), this->iter_); + } else if (lr_policy == "inv") { + rate = this->param_.base_lr() * + pow(Dtype(1) + this->param_.gamma() * this->iter_, + - this->param_.power()); + } else if (lr_policy == "multistep") { + if (this->current_step_ < this->param_.stepvalue_size() && + this->iter_ >= this->param_.stepvalue(this->current_step_)) { + this->current_step_++; + LOG(INFO) << "MultiStep Status: Iteration " << + this->iter_ << ", step = " << this->current_step_; + } + rate = this->param_.base_lr() * + pow(this->param_.gamma(), this->current_step_); + } else if (lr_policy == "poly") { + rate = this->param_.base_lr() * pow(Dtype(1.) - + (Dtype(this->iter_) / Dtype(this->param_.max_iter())), + this->param_.power()); + } else if (lr_policy == "sigmoid") { + rate = this->param_.base_lr() * (Dtype(1.) / + (Dtype(1.) 
+ exp(-this->param_.gamma() * (Dtype(this->iter_) - + Dtype(this->param_.stepsize()))))); + } else { + LOG(FATAL) << "Unknown learning rate policy: " << lr_policy; + } + return rate; +} + +template +void SGDSolver::PreSolve() { + // Initialize the history + const vector*>& net_params = this->net_->learnable_params(); + history_.clear(); + update_.clear(); + temp_.clear(); + for (int i = 0; i < net_params.size(); ++i) { + const vector& shape = net_params[i]->shape(); + history_.push_back(shared_ptr >(new Blob(shape))); + update_.push_back(shared_ptr >(new Blob(shape))); + temp_.push_back(shared_ptr >(new Blob(shape))); + } +} + +template +void SGDSolver::ClipGradients() { + const Dtype clip_gradients = this->param_.clip_gradients(); + if (clip_gradients < 0) { return; } + const vector*>& net_params = this->net_->learnable_params(); + Dtype sumsq_diff = 0; + for (int i = 0; i < net_params.size(); ++i) { + sumsq_diff += net_params[i]->sumsq_diff(); + } + const Dtype l2norm_diff = std::sqrt(sumsq_diff); + if (l2norm_diff > clip_gradients) { + Dtype scale_factor = clip_gradients / l2norm_diff; + LOG(INFO) << "Gradient clipping: scaling down gradients (L2 norm " + << l2norm_diff << " > " << clip_gradients << ") " + << "by scale factor " << scale_factor; + for (int i = 0; i < net_params.size(); ++i) { + net_params[i]->scale_diff(scale_factor); + } + } +} + +template +void SGDSolver::ApplyUpdate() { + CHECK(Caffe::root_solver()); + Dtype rate = GetLearningRate(); + if (this->param_.display() && this->iter_ % this->param_.display() == 0) { + LOG(INFO) << "Iteration " << this->iter_ << ", lr = " << rate; + } + ClipGradients(); + for (int param_id = 0; param_id < this->net_->learnable_params().size(); + ++param_id) { + Normalize(param_id); + Regularize(param_id); + ComputeUpdateValue(param_id, rate); + } + this->net_->Update(); +} + +template +void SGDSolver::Normalize(int param_id) { + if (this->param_.iter_size() == 1) { return; } + // Scale gradient to counterbalance 
accumulation. + const vector*>& net_params = this->net_->learnable_params(); + const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); + switch (Caffe::mode()) { + case Caffe::CPU: { + caffe_scal(net_params[param_id]->count(), accum_normalization, + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + caffe_gpu_scal(net_params[param_id]->count(), accum_normalization, + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void SGDSolver::Regularize(int param_id) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_weight_decay = + this->net_->params_weight_decay(); + Dtype weight_decay = this->param_.weight_decay(); + string regularization_type = this->param_.regularization_type(); + Dtype local_decay = weight_decay * net_params_weight_decay[param_id]; + switch (Caffe::mode()) { + case Caffe::CPU: { + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else if (regularization_type == "L1") { + caffe_cpu_sign(net_params[param_id]->count(), + net_params[param_id]->cpu_data(), + temp_[param_id]->mutable_cpu_data()); + caffe_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + if (local_decay) { + if (regularization_type == "L2") { + // add weight decay + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + net_params[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else if (regularization_type == "L1") { + 
caffe_gpu_sign(net_params[param_id]->count(), + net_params[param_id]->gpu_data(), + temp_[param_id]->mutable_gpu_data()); + caffe_gpu_axpy(net_params[param_id]->count(), + local_decay, + temp_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); + } else { + LOG(FATAL) << "Unknown regularization type: " << regularization_type; + } + } +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { + const vector*>& net_params = this->net_->learnable_params(); + const vector& net_params_lr = this->net_->params_lr(); + Dtype momentum = this->param_.momentum(); + Dtype local_rate = rate * net_params_lr[param_id]; + // Compute the update to history, then copy it to the parameter diff. + switch (Caffe::mode()) { + case Caffe::CPU: { + caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->cpu_diff(), momentum, + history_[param_id]->mutable_cpu_data()); + caffe_copy(net_params[param_id]->count(), + history_[param_id]->cpu_data(), + net_params[param_id]->mutable_cpu_diff()); + break; + } + case Caffe::GPU: { +#ifndef CPU_ONLY + caffe_gpu_axpby(net_params[param_id]->count(), local_rate, + net_params[param_id]->gpu_diff(), momentum, + history_[param_id]->mutable_gpu_data()); + caffe_copy(net_params[param_id]->count(), + history_[param_id]->gpu_data(), + net_params[param_id]->mutable_gpu_diff()); +#else + NO_GPU; +#endif + break; + } + default: + LOG(FATAL) << "Unknown caffe mode: " << Caffe::mode(); + } +} + +template +void SGDSolver::SnapshotSolverState(const string& model_filename) { + switch (this->param_.snapshot_format()) { + case caffe::SolverParameter_SnapshotFormat_BINARYPROTO: + SnapshotSolverStateToBinaryProto(model_filename); + break; + case caffe::SolverParameter_SnapshotFormat_HDF5: + SnapshotSolverStateToHDF5(model_filename); + break; + default: + LOG(FATAL) << "Unsupported snapshot format."; + 
} +} + +template +void SGDSolver::SnapshotSolverStateToBinaryProto( + const string& model_filename) { + SolverState state; + state.set_iter(this->iter_); + state.set_learned_net(model_filename); + state.set_current_step(this->current_step_); + state.clear_history(); + for (int i = 0; i < history_.size(); ++i) { + // Add history + BlobProto* history_blob = state.add_history(); + history_[i]->ToProto(history_blob); + } + string snapshot_filename = Solver::SnapshotFilename(".solverstate"); + LOG(INFO) + << "Snapshotting solver state to binary proto file " << snapshot_filename; + WriteProtoToBinaryFile(state, snapshot_filename.c_str()); +} + +template +void SGDSolver::SnapshotSolverStateToHDF5( + const string& model_filename) { + string snapshot_filename = + Solver::SnapshotFilename(".solverstate.h5"); + LOG(INFO) << "Snapshotting solver state to HDF5 file " << snapshot_filename; + hid_t file_hid = H5Fcreate(snapshot_filename.c_str(), H5F_ACC_TRUNC, + H5P_DEFAULT, H5P_DEFAULT); + CHECK_GE(file_hid, 0) + << "Couldn't open " << snapshot_filename << " to save solver state."; + hdf5_save_int(file_hid, "iter", this->iter_); + hdf5_save_string(file_hid, "learned_net", model_filename); + hdf5_save_int(file_hid, "current_step", this->current_step_); + hid_t history_hid = H5Gcreate2(file_hid, "history", H5P_DEFAULT, H5P_DEFAULT, + H5P_DEFAULT); + CHECK_GE(history_hid, 0) + << "Error saving solver state to " << snapshot_filename << "."; + for (int i = 0; i < history_.size(); ++i) { + ostringstream oss; + oss << i; + hdf5_save_nd_dataset(history_hid, oss.str(), *history_[i]); + } + H5Gclose(history_hid); + H5Fclose(file_hid); +} + +template +void SGDSolver::RestoreSolverStateFromBinaryProto( + const string& state_file) { + SolverState state; + ReadProtoFromBinaryFile(state_file, &state); + this->iter_ = state.iter(); + if (state.has_learned_net()) { + NetParameter net_param; + ReadNetParamsFromBinaryFileOrDie(state.learned_net().c_str(), &net_param); + 
this->net_->CopyTrainedLayersFrom(net_param); + } + this->current_step_ = state.current_step(); + CHECK_EQ(state.history_size(), history_.size()) + << "Incorrect length of history blobs."; + LOG(INFO) << "SGDSolver: restoring history"; + for (int i = 0; i < history_.size(); ++i) { + history_[i]->FromProto(state.history(i)); + } +} + +template +void SGDSolver::RestoreSolverStateFromHDF5(const string& state_file) { + hid_t file_hid = H5Fopen(state_file.c_str(), H5F_ACC_RDONLY, H5P_DEFAULT); + CHECK_GE(file_hid, 0) << "Couldn't open solver state file " << state_file; + this->iter_ = hdf5_load_int(file_hid, "iter"); + if (H5LTfind_dataset(file_hid, "learned_net")) { + string learned_net = hdf5_load_string(file_hid, "learned_net"); + this->net_->CopyTrainedLayersFrom(learned_net); + } + this->current_step_ = hdf5_load_int(file_hid, "current_step"); + hid_t history_hid = H5Gopen2(file_hid, "history", H5P_DEFAULT); + CHECK_GE(history_hid, 0) << "Error reading history from " << state_file; + int state_history_size = hdf5_get_num_links(history_hid); + CHECK_EQ(state_history_size, history_.size()) + << "Incorrect length of history blobs."; + for (int i = 0; i < history_.size(); ++i) { + ostringstream oss; + oss << i; + hdf5_load_nd_dataset(history_hid, oss.str().c_str(), 0, + kMaxBlobAxes, history_[i].get()); + } + H5Gclose(history_hid); + H5Fclose(file_hid); +} + +INSTANTIATE_CLASS(SGDSolver); +REGISTER_SOLVER_CLASS(SGD); + +} // namespace caffe diff --git a/src/caffe/test/test_gradient_based_solver.cpp b/src/caffe/test/test_gradient_based_solver.cpp index 7ad7467f86f..84c6747f61a 100644 --- a/src/caffe/test/test_gradient_based_solver.cpp +++ b/src/caffe/test/test_gradient_based_solver.cpp @@ -10,7 +10,7 @@ #include "caffe/common.hpp" #include "caffe/parallel.hpp" #include "caffe/proto/caffe.pb.h" -#include "caffe/solver.hpp" +#include "caffe/sgd_solvers.hpp" #include "caffe/util/io.hpp" #include "caffe/test/test_caffe_main.hpp" @@ -47,7 +47,6 @@ class 
GradientBasedSolverTest : public MultiDeviceTest { // Test data: check out generate_sample_data.py in the same directory. string* input_file_; - virtual SolverParameter_SolverType solver_type() = 0; virtual void InitSolver(const SolverParameter& param) = 0; virtual void InitSolverFromProtoString(const string& proto) { @@ -290,8 +289,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { ((i == D) ? bias.cpu_data()[0] : weights.cpu_data()[i]); // Finally, compute update. const vector > >& history = solver_->history(); - if (solver_type() != SolverParameter_SolverType_ADADELTA - && solver_type() != SolverParameter_SolverType_ADAM) { + if (solver_->type() != string("AdaDelta") + && solver_->type() != string("Adam")) { ASSERT_EQ(2, history.size()); // 1 blob for weights, 1 for bias } else { ASSERT_EQ(4, history.size()); // additional blobs for update history @@ -300,26 +299,19 @@ class GradientBasedSolverTest : public MultiDeviceTest { const Dtype history_value = (i == D) ? history[1]->cpu_data()[0] : history[0]->cpu_data()[i]; const Dtype temp = momentum * history_value; - switch (solver_type()) { - case SolverParameter_SolverType_SGD: + if (solver_->type() == string("SGD")) { update_value += temp; - break; - case SolverParameter_SolverType_NESTEROV: + } else if (solver_->type() == string("Nesterov")) { update_value += temp; // step back then over-step update_value = (1 + momentum) * update_value - temp; - break; - case SolverParameter_SolverType_ADAGRAD: + } else if (solver_->type() == string("AdaGrad")) { update_value /= std::sqrt(history_value + grad * grad) + delta_; - break; - case SolverParameter_SolverType_RMSPROP: { + } else if (solver_->type() == string("RMSProp")) { const Dtype rms_decay = 0.95; update_value /= std::sqrt(rms_decay*history_value + grad * grad * (1 - rms_decay)) + delta_; - } - break; - case SolverParameter_SolverType_ADADELTA: - { + } else if (solver_->type() == string("AdaDelta")) { const Dtype update_history_value = (i == D) ? 
history[1 + num_param_blobs]->cpu_data()[0] : history[0 + num_param_blobs]->cpu_data()[i]; @@ -330,9 +322,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { // not actually needed, just here for illustrative purposes // const Dtype weighted_update_average = // momentum * update_history_value + (1 - momentum) * (update_value); - break; - } - case SolverParameter_SolverType_ADAM: { + } else if (solver_->type() == string("Adam")) { const Dtype momentum2 = 0.999; const Dtype m = history_value; const Dtype v = (i == D) ? @@ -344,10 +334,8 @@ class GradientBasedSolverTest : public MultiDeviceTest { std::sqrt(Dtype(1) - pow(momentum2, num_iters)) / (Dtype(1.) - pow(momentum, num_iters)); update_value = alpha_t * val_m / (std::sqrt(val_v) + delta_); - break; - } - default: - LOG(FATAL) << "Unknown solver type: " << solver_type(); + } else { + LOG(FATAL) << "Unknown solver type: " << solver_->type(); } if (i == D) { updated_bias.mutable_cpu_diff()[0] = update_value; @@ -392,7 +380,7 @@ class GradientBasedSolverTest : public MultiDeviceTest { EXPECT_NEAR(expected_updated_bias, solver_updated_bias, error_margin); // Check the solver's history -- should contain the previous update value. 
- if (solver_type() == SolverParameter_SolverType_SGD) { + if (solver_->type() == string("SGD")) { const vector > >& history = solver_->history(); ASSERT_EQ(2, history.size()); for (int i = 0; i < D; ++i) { @@ -581,10 +569,6 @@ class SGDSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new SGDSolver(param)); } - - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_SGD; - } }; TYPED_TEST_CASE(SGDSolverTest, TestDtypesAndDevices); @@ -721,9 +705,6 @@ class AdaGradSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaGradSolver(param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADAGRAD; - } }; TYPED_TEST_CASE(AdaGradSolverTest, TestDtypesAndDevices); @@ -824,9 +805,6 @@ class NesterovSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new NesterovSolver(param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_NESTEROV; - } }; TYPED_TEST_CASE(NesterovSolverTest, TestDtypesAndDevices); @@ -960,10 +938,6 @@ class AdaDeltaSolverTest : public GradientBasedSolverTest { virtual void InitSolver(const SolverParameter& param) { this->solver_.reset(new AdaDeltaSolver(param)); } - - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADADELTA; - } }; TYPED_TEST_CASE(AdaDeltaSolverTest, TestDtypesAndDevices); @@ -1098,9 +1072,6 @@ class AdamSolverTest : public GradientBasedSolverTest { new_param.set_momentum2(momentum2); this->solver_.reset(new AdamSolver(new_param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_ADAM; - } }; TYPED_TEST_CASE(AdamSolverTest, TestDtypesAndDevices); @@ -1201,9 +1172,6 @@ class RMSPropSolverTest : public GradientBasedSolverTest { 
new_param.set_rms_decay(rms_decay); this->solver_.reset(new RMSPropSolver(new_param)); } - virtual SolverParameter_SolverType solver_type() { - return SolverParameter_SolverType_RMSPROP; - } }; TYPED_TEST_CASE(RMSPropSolverTest, TestDtypesAndDevices); diff --git a/src/caffe/test/test_solver.cpp b/src/caffe/test/test_solver.cpp index ceabc9cdd2c..b181642681c 100644 --- a/src/caffe/test/test_solver.cpp +++ b/src/caffe/test/test_solver.cpp @@ -7,6 +7,7 @@ #include "caffe/common.hpp" #include "caffe/proto/caffe.pb.h" +#include "caffe/sgd_solvers.hpp" #include "caffe/solver.hpp" #include "caffe/test/test_caffe_main.hpp" diff --git a/src/caffe/test/test_solver_factory.cpp b/src/caffe/test/test_solver_factory.cpp new file mode 100644 index 00000000000..eef5290fe2e --- /dev/null +++ b/src/caffe/test/test_solver_factory.cpp @@ -0,0 +1,50 @@ +#include +#include + +#include "boost/scoped_ptr.hpp" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +#include "caffe/common.hpp" +#include "caffe/solver.hpp" +#include "caffe/solver_factory.hpp" + +#include "caffe/test/test_caffe_main.hpp" + +namespace caffe { + +template +class SolverFactoryTest : public MultiDeviceTest { + protected: + SolverParameter simple_solver_param() { + const string solver_proto = + "train_net_param { " + " layer { " + " name: 'data' type: 'DummyData' top: 'data' " + " dummy_data_param { shape { dim: 1 } } " + " } " + "} "; + SolverParameter solver_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + solver_proto, &solver_param)); + return solver_param; + } +}; + +TYPED_TEST_CASE(SolverFactoryTest, TestDtypesAndDevices); + +TYPED_TEST(SolverFactoryTest, TestCreateSolver) { + typedef typename TypeParam::Dtype Dtype; + typename SolverRegistry::CreatorRegistry& registry = + SolverRegistry::Registry(); + shared_ptr > solver; + SolverParameter solver_param = this->simple_solver_param(); + for (typename SolverRegistry::CreatorRegistry::iterator iter = + registry.begin(); iter != 
registry.end(); ++iter) { + solver_param.set_type(iter->first); + solver.reset(SolverRegistry::CreateSolver(solver_param)); + EXPECT_EQ(iter->first, solver->type()); + } +} + +} // namespace caffe diff --git a/src/caffe/test/test_upgrade_proto.cpp b/src/caffe/test/test_upgrade_proto.cpp index ee05b151e72..df9aeb62464 100644 --- a/src/caffe/test/test_upgrade_proto.cpp +++ b/src/caffe/test/test_upgrade_proto.cpp @@ -2928,4 +2928,65 @@ TEST_F(NetUpgradeTest, TestUpgradeV1LayerType) { } } #endif // USE_OPENCV + +class SolverTypeUpgradeTest : public ::testing::Test { + protected: + void RunSolverTypeUpgradeTest( + const string& input_param_string, const string& output_param_string) { + // Test upgrading old solver_type field (enum) to new type field (string) + SolverParameter input_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + input_param_string, &input_param)); + SolverParameter expected_output_param; + CHECK(google::protobuf::TextFormat::ParseFromString( + output_param_string, &expected_output_param)); + SolverParameter actual_output_param = input_param; + UpgradeSolverType(&actual_output_param); + EXPECT_EQ(expected_output_param.DebugString(), + actual_output_param.DebugString()); + } +}; + +TEST_F(SolverTypeUpgradeTest, TestSimple) { + const char* old_type_vec[6] = { "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", + "ADADELTA", "ADAM" }; + const char* new_type_vec[6] = { "SGD", "AdaGrad", "Nesterov", "RMSProp", + "AdaDelta", "Adam" }; + for (int i = 0; i < 6; ++i) { + const string& input_proto = + "net: 'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "solver_type: " + std::string(old_type_vec[i]) + " "; + const string& expected_output_proto = + "net: 
'examples/mnist/lenet_train_test.prototxt' " + "test_iter: 100 " + "test_interval: 500 " + "base_lr: 0.01 " + "momentum: 0.0 " + "weight_decay: 0.0005 " + "lr_policy: 'inv' " + "gamma: 0.0001 " + "power: 0.75 " + "display: 100 " + "max_iter: 10000 " + "snapshot: 5000 " + "snapshot_prefix: 'examples/mnist/lenet_rmsprop' " + "solver_mode: GPU " + "type: '" + std::string(new_type_vec[i]) + "' "; + this->RunSolverTypeUpgradeTest(input_proto, expected_output_proto); + } +} + } // NOLINT(readability/fn_size) // namespace caffe diff --git a/src/caffe/util/upgrade_proto.cpp b/src/caffe/util/upgrade_proto.cpp index 6eae9fec00a..ff3f8ffc4f0 100644 --- a/src/caffe/util/upgrade_proto.cpp +++ b/src/caffe/util/upgrade_proto.cpp @@ -937,4 +937,78 @@ const char* UpgradeV1LayerType(const V1LayerParameter_LayerType type) { } } +// Return true iff the solver contains any old solver_type specified as enums +bool SolverNeedsTypeUpgrade(const SolverParameter& solver_param) { + if (solver_param.has_solver_type()) { + return true; + } + return false; +} + +bool UpgradeSolverType(SolverParameter* solver_param) { + CHECK(!solver_param->has_solver_type() || !solver_param->has_type()) + << "Failed to upgrade solver: old solver_type field (enum) and new type " + << "field (string) cannot be both specified in solver proto text."; + if (solver_param->has_solver_type()) { + string type; + switch (solver_param->solver_type()) { + case SolverParameter_SolverType_SGD: + type = "SGD"; + break; + case SolverParameter_SolverType_NESTEROV: + type = "Nesterov"; + break; + case SolverParameter_SolverType_ADAGRAD: + type = "AdaGrad"; + break; + case SolverParameter_SolverType_RMSPROP: + type = "RMSProp"; + break; + case SolverParameter_SolverType_ADADELTA: + type = "AdaDelta"; + break; + case SolverParameter_SolverType_ADAM: + type = "Adam"; + break; + default: + LOG(FATAL) << "Unknown SolverParameter solver_type: " << type; + } + solver_param->set_type(type); + solver_param->clear_solver_type(); + } else 
{ + LOG(ERROR) << "Warning: solver type already up to date. "; + return false; + } + return true; +} + +// Check for deprecations and upgrade the SolverParameter as needed. +bool UpgradeSolverAsNeeded(const string& param_file, SolverParameter* param) { + bool success = true; + // Try to upgrade old style solver_type enum fields into new string type + if (SolverNeedsTypeUpgrade(*param)) { + LOG(INFO) << "Attempting to upgrade input file specified using deprecated " + << "'solver_type' field (enum)': " << param_file; + if (!UpgradeSolverType(param)) { + success = false; + LOG(ERROR) << "Warning: had one or more problems upgrading " + << "SolverType (see above)."; + } else { + LOG(INFO) << "Successfully upgraded file specified using deprecated " + << "'solver_type' field (enum) to 'type' field (string)."; + LOG(WARNING) << "Note that future Caffe releases will only support " + << "'type' field (string) for a solver's type."; + } + } + return success; +} + +// Read parameters from a file into a SolverParameter proto message. +void ReadSolverParamsFromTextFileOrDie(const string& param_file, + SolverParameter* param) { + CHECK(ReadProtoFromTextFile(param_file, param)) + << "Failed to parse SolverParameter file: " << param_file; + UpgradeSolverAsNeeded(param_file, param); +} + } // namespace caffe diff --git a/tools/caffe.cpp b/tools/caffe.cpp index e3f684b5ab3..305cfc3635d 100644 --- a/tools/caffe.cpp +++ b/tools/caffe.cpp @@ -157,7 +157,7 @@ int train() { "but not both."; caffe::SolverParameter solver_param; - caffe::ReadProtoFromTextFileOrDie(FLAGS_solver, &solver_param); + caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param); // If the gpus flag is not provided, allow the mode and device to be set // in the solver prototxt. 
@@ -194,7 +194,7 @@ int train() { GetRequestedAction(FLAGS_sighup_effect)); shared_ptr > - solver(caffe::GetSolver(solver_param)); + solver(caffe::SolverRegistry::CreateSolver(solver_param)); solver->SetActionFunction(signal_handler.GetActionFunction()); diff --git a/tools/upgrade_solver_proto_text.cpp b/tools/upgrade_solver_proto_text.cpp new file mode 100644 index 00000000000..7130232aed7 --- /dev/null +++ b/tools/upgrade_solver_proto_text.cpp @@ -0,0 +1,50 @@ +// This is a script to upgrade old solver prototxts to the new format. +// Usage: +// upgrade_solver_proto_text old_solver_proto_file_in solver_proto_file_out + +#include +#include // NOLINT(readability/streams) +#include // NOLINT(readability/streams) +#include + +#include "caffe/caffe.hpp" +#include "caffe/util/io.hpp" +#include "caffe/util/upgrade_proto.hpp" + +using std::ofstream; + +using namespace caffe; // NOLINT(build/namespaces) + +int main(int argc, char** argv) { + ::google::InitGoogleLogging(argv[0]); + if (argc != 3) { + LOG(ERROR) << "Usage: upgrade_solver_proto_text " + << "old_solver_proto_file_in solver_proto_file_out"; + return 1; + } + + SolverParameter solver_param; + string input_filename(argv[1]); + if (!ReadProtoFromTextFile(input_filename, &solver_param)) { + LOG(ERROR) << "Failed to parse input text file as SolverParameter: " + << input_filename; + return 2; + } + bool need_upgrade = SolverNeedsTypeUpgrade(solver_param); + bool success = true; + if (need_upgrade) { + success = UpgradeSolverAsNeeded(input_filename, &solver_param); + if (!success) { + LOG(ERROR) << "Encountered error(s) while upgrading prototxt; " + << "see details above."; + } + } else { + LOG(ERROR) << "File already in latest proto format: " << input_filename; + } + + // Save new format prototxt. + WriteProtoToTextFile(solver_param, argv[2]); + + LOG(ERROR) << "Wrote upgraded SolverParameter text proto to " << argv[2]; + return !success; +}