Skip to content

Commit

Permalink
Merge pull request #3082 from gustavla/pycaffe-snapshot
Browse files Browse the repository at this point in the history
Expose `Solver::Snapshot` to pycaffe
  • Loading branch information
shelhamer committed Oct 31, 2015
2 parents ca4e342 + 19d9927 commit f5fd18b
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 7 deletions.
10 changes: 5 additions & 5 deletions include/caffe/solver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class Solver {
// RestoreSolverStateFrom___ protected methods. You should implement these
// methods to restore the state from the appropriate snapshot type.
void Restore(const char* resume_file);
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
virtual ~Solver() {}
inline const SolverParameter& param() const { return param_; }
inline shared_ptr<Net<Dtype> > net() { return net_; }
Expand Down Expand Up @@ -92,11 +97,6 @@ class Solver {
protected:
// Make and apply the update value for the current iteration.
virtual void ApplyUpdate() = 0;
// The Solver::Snapshot function implements the basic snapshotting utility
// that stores the learned net. You should implement the SnapshotSolverState()
// function that produces a SolverState protocol buffer that needs to be
// written to disk together with the learned net.
void Snapshot();
string SnapshotFilename(const string extension);
string SnapshotToBinaryProto();
string SnapshotToHDF5();
Expand Down
3 changes: 2 additions & 1 deletion python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,8 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("solve", static_cast<void (Solver<Dtype>::*)(const char*)>(
&Solver<Dtype>::Solve), SolveOverloads())
.def("step", &Solver<Dtype>::Step)
.def("restore", &Solver<Dtype>::Restore);
.def("restore", &Solver<Dtype>::Restore)
.def("snapshot", &Solver<Dtype>::Snapshot);

bp::class_<SGDSolver<Dtype>, bp::bases<Solver<Dtype> >,
shared_ptr<SGDSolver<Dtype> >, boost::noncopyable>(
Expand Down
11 changes: 10 additions & 1 deletion python/caffe/test/test_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ def setUp(self):
f.write("""net: '""" + net_f + """'
test_iter: 10 test_interval: 10 base_lr: 0.01 momentum: 0.9
weight_decay: 0.0005 lr_policy: 'inv' gamma: 0.0001 power: 0.75
display: 100 max_iter: 100 snapshot_after_train: false""")
display: 100 max_iter: 100 snapshot_after_train: false
snapshot_prefix: "model" """)
f.close()
self.solver = caffe.SGDSolver(f.name)
# also make sure get_solver runs
Expand Down Expand Up @@ -51,3 +52,11 @@ def test_net_memory(self):
total += p.data.sum() + p.diff.sum()
for bl in six.itervalues(net.blobs):
total += bl.data.sum() + bl.diff.sum()

def test_snapshot(self):
self.solver.snapshot()
# Check that these files exist and then remove them
files = ['model_iter_0.caffemodel', 'model_iter_0.solverstate']
for fn in files:
assert os.path.isfile(fn)
os.remove(fn)

0 comments on commit f5fd18b

Please sign in to comment.