Skip to content

Commit

Permalink
Merge pull request BVLC#5408 from cypof/multi_infer
Browse files Browse the repository at this point in the history
Init test network on all GPUs
  • Loading branch information
cypof authored Apr 12, 2017
2 parents 90eff9b + 8602a23 commit 41a7d21
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
5 changes: 5 additions & 0 deletions python/caffe/_caffe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,10 @@ void Solver_add_nccl(Solver<Dtype>* solver
#endif
}

void share_weights(Solver<Dtype>* solver, Net<Dtype>* net) {
net->ShareTrainedLayersWith(solver->net().get());
}

template<typename Dtype>
class NetCallback: public Net<Dtype>::Callback {
public:
Expand Down Expand Up @@ -459,6 +463,7 @@ BOOST_PYTHON_MODULE(_caffe) {
.def("step", &Solver<Dtype>::Step)
.def("restore", &Solver<Dtype>::Restore)
.def("snapshot", &Solver<Dtype>::Snapshot)
.def("share_weights", &share_weights)
.add_property("param", bp::make_function(&Solver<Dtype>::param,
bp::return_value_policy<bp::copy_const_reference>()));
BP_REGISTER_SHARED_PTR_TO_PYTHON(Solver<Dtype>);
Expand Down
3 changes: 1 addition & 2 deletions src/caffe/solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,8 @@ void Solver<Dtype>::Init(const SolverParameter& param) {
}
// Scaffolding code
InitTrainNet();
InitTestNets();
if (Caffe::root_solver()) {
InitTestNets();
LOG(INFO) << "Solver scaffolding done.";
}
iter_ = 0;
Expand Down Expand Up @@ -102,7 +102,6 @@ void Solver<Dtype>::InitTrainNet() {

template <typename Dtype>
void Solver<Dtype>::InitTestNets() {
CHECK(Caffe::root_solver());
const bool has_net_param = param_.has_net_param();
const bool has_net_file = param_.has_net();
const int num_generic_nets = has_net_param + has_net_file;
Expand Down

0 comments on commit 41a7d21

Please sign in to comment.