Skip to content

Commit

Permalink
[pycaffe] expose mutable solver parameter, base lr, and effective lr
Browse files Browse the repository at this point in the history
`solver.lr` is the effective learning rate in use, while `solver.base_lr`
is the configured learning rate at initialization. The solver parameter
is now editable for setting fields that are in use throughout the
lifetime of the solver, such as the maximum number of iterations.
  • Loading branch information
mitar authored and shelhamer committed Jun 7, 2018
1 parent c74913d commit cfcf74f
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 9 deletions.
2 changes: 1 addition & 1 deletion include/caffe/sgd_solvers.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@ class SGDSolver : public Solver<Dtype> {
const vector<shared_ptr<Blob<Dtype> > >& history() { return history_; }

virtual void ApplyUpdate();
Dtype GetLearningRate();

protected:
void PreSolve();
Dtype GetLearningRate();
virtual void Normalize(int param_id);
virtual void Regularize(int param_id);
virtual void ComputeUpdateValue(int param_id, Dtype rate);
Expand Down
19 changes: 11 additions & 8 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,9 @@ BOOST_PYTHON_MODULE(_caffe) {
bp::class_<SolverParameter>("SolverParameter", bp::no_init)
.add_property("max_iter", &SolverParameter::max_iter)
.add_property("display", &SolverParameter::display)
.add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce);
.add_property("layer_wise_reduce", &SolverParameter::layer_wise_reduce)
.add_property("base_lr", &SolverParameter::base_lr,
&SolverParameter::set_base_lr);
bp::class_<LayerParameter>("LayerParameter", bp::no_init);

bp::class_<Solver<Dtype>, shared_ptr<Solver<Dtype> >, boost::noncopyable>(
Expand All @@ -509,25 +511,26 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("share_weights", &share_weights)
.def("apply_update", &Solver<Dtype>::ApplyUpdate)
.add_property("param", bp::make_function(&Solver<Dtype>::param,
bp::return_value_policy<bp::copy_const_reference>()));
bp::return_internal_reference<>()));
BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);

bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
"SGDSolver", bp::init<string>());
bp::class_<NesterovSolver<Dtype>, bp::bases<Solver<Dtype> >,
"SGDSolver", bp::init<string>())
.add_property("lr", &SGDSolver<Dtype>::GetLearningRate);
bp::class_<NesterovSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
shared_ptr<NesterovSolver<Dtype> >, boost::noncopyable>(
"NesterovSolver", bp::init<string>());
bp::class_<AdaGradSolver<Dtype>, bp::bases<Solver<Dtype> >,
bp::class_<AdaGradSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
shared_ptr<AdaGradSolver<Dtype> >, boost::noncopyable>(
"AdaGradSolver", bp::init<string>());
bp::class_<RMSPropSolver<Dtype>, bp::bases<Solver<Dtype> >,
bp::class_<RMSPropSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
shared_ptr<RMSPropSolver<Dtype> >, boost::noncopyable>(
"RMSPropSolver", bp::init<string>());
bp::class_<AdaDeltaSolver<Dtype>, bp::bases<Solver<Dtype> >,
bp::class_<AdaDeltaSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
shared_ptr<AdaDeltaSolver<Dtype> >, boost::noncopyable>(
"AdaDeltaSolver", bp::init<string>());
bp::class_<AdamSolver<Dtype>, bp::bases<Solver<Dtype> >,
bp::class_<AdamSolver<Dtype>, bp::bases<SGDSolver<Dtype> >,
shared_ptr<AdamSolver<Dtype> >, boost::noncopyable>(
"AdamSolver", bp::init<string>());

Expand Down

0 comments on commit cfcf74f

Please sign in to comment.