From c55cac361e62ef1d99b9974d1d4d51665ad26e66 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Sat, 26 Aug 2017 06:42:13 +0800 Subject: [PATCH] Support padded layout --- src/caffe/mkldnn_memory.cpp | 3 +- src/caffe/solvers/sgd_solver.cpp | 107 ++++--------------------------- 2 files changed, 13 insertions(+), 97 deletions(-) diff --git a/src/caffe/mkldnn_memory.cpp b/src/caffe/mkldnn_memory.cpp index 6e42e691d..c53cff7ff 100644 --- a/src/caffe/mkldnn_memory.cpp +++ b/src/caffe/mkldnn_memory.cpp @@ -212,8 +212,7 @@ void MKLDNNMemoryDescriptor::convert_from_extprv(shared_ptr_reorder_extprv2prv_pd == NULL) return; - if (this->_extprv_memory_pd->desc().data.format == this->_prv_memory_pd->desc().data.format && - this->_extprv_memory_pd->desc().data.data_type == this->_prv_memory_pd->desc().data.data_type) + if (*this->_extprv_memory_pd == *this->_prv_memory_pd) { #ifdef DEBUG LOG(INFO) << "The format and data_type of _extprv_memory_pd and _prv_memory_pd is same, no need do conversion."; diff --git a/src/caffe/solvers/sgd_solver.cpp b/src/caffe/solvers/sgd_solver.cpp index 929ff050f..6a7e2ca43 100644 --- a/src/caffe/solvers/sgd_solver.cpp +++ b/src/caffe/solvers/sgd_solver.cpp @@ -354,38 +354,25 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { bool prv_diff_condition_flag = false; if (net_params[param_id]->prv_diff() && (net_params[param_id]->prv_diff_count() - == net_params[param_id]->prv_data_count())) { + == net_params[param_id]->count())) { prv_diff_condition_flag = true; - //LOG(INFO) << "Common condition judgement: prv_diff_condition_flag = true."; - } - else - { - //LOG(INFO) << "Common condition judgement: prv_diff_condition_flag = false."; } //#pragma endregion //#pragma region 3. Normalize stage if (skip_Normalize_stage_flag == false) { - //LOG(INFO) << "Normalize stage: Normalize stage is not skipped."; - const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size(); if (prv_diff_condition_flag) { - //LOG(INFO) << "Normalize stage: prv_diff_condition_flag = true."; - caffe_scal(net_params[param_id]->prv_data_count(), accum_normalization, + caffe_scal(net_params[param_id]->prv_diff_count(), accum_normalization, net_params[param_id]->mutable_prv_diff()); } else { - //LOG(INFO) << "Normalize stage: prv_diff_condition_flag = false."; caffe_scal(net_params[param_id]->count(), accum_normalization, net_params[param_id]->mutable_cpu_diff()); } } - else - { - //LOG(INFO) << "Normalize stage: Normalize stage is skipped."; - } //#pragma endregion //For most common topologies from BVLC, all skipped the Normalize stage, and use L2 regularization @@ -401,97 +388,35 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { //Regularize stage (Fused ComputeUpdateValue_stage in some situations) if (local_decay) { if (regularization_type == "L2") { - //LOG(INFO) << "Regularize stage: regularization_type == L2."; // add weight decay if (net_params[param_id]->prv_data() && (net_params[param_id]->prv_data_count() == net_params[param_id]->count())) { - //LOG(INFO) << "Regularize stage: prv_data_condition_flag = true."; CHECK_EQ(true, net_params[param_id]->get_prv_data_descriptor()->layout_compare( net_params[param_id]->get_prv_diff_descriptor())); - /* - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->prv_data(), - net_params[param_id]->mutable_prv_diff()); - */ if (prv_diff_condition_flag) { - //situation (1) - //LOG(INFO) << "Fused ComputeUpdateValue stage: prv_diff_condition_flag = true."; - /* - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->prv_diff(), momentum, - history_[param_id]->mutable_cpu_data()); - - caffe_copy(net_params[param_id]->count(), - history_[param_id]->cpu_data(), - net_params[param_id]->mutable_prv_diff()); - */ - - if(net_params[param_id]->prv_data_count() != history_[param_id]->count()) - history_[param_id]->Reshape(net_params[param_id]->shape()); - axpy_axpby_copy_axpy(net_params[param_id]->prv_data_count(), local_decay, net_params[param_id]->mutable_prv_data(), net_params[param_id]->mutable_prv_diff(), local_rate, momentum, history_[param_id]->mutable_cpu_data(), Dtype(-1)); is_separate_ComputeUpdateValue_Update = false; } - else - { - //Will not happen! - //LOG(INFO) << "Cannot Fused ComputeUpdateValue stage: prv_diff_condition_flag = false."; - } } else { - //LOG(INFO) << "Regularize stage: prv_data_condition_flag = false."; - /* - caffe_axpy(net_params[param_id]->count(), - local_decay, - net_params[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - */ if (!prv_diff_condition_flag) { - //situation (2) - //LOG(INFO) << "Fused ComputeUpdateValue stage: prv_diff_condition_flag = false."; - /* - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, - net_params[param_id]->cpu_diff(), momentum, - history_[param_id]->mutable_cpu_data()); - - caffe_copy(net_params[param_id]->count(), - history_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - */ - axpy_axpby_copy_axpy(net_params[param_id]->count(), local_decay, net_params[param_id]->mutable_cpu_data(), net_params[param_id]->mutable_cpu_diff(), local_rate, momentum, history_[param_id]->mutable_cpu_data(), Dtype(-1)); is_separate_ComputeUpdateValue_Update = false; } - else - { - //Will not happen! - //LOG(INFO) << "Cannot Fused ComputeUpdateValue stage: prv_diff_condition_flag = true."; - if(net_params[param_id]->prv_data_count() != history_[param_id]->count()) - history_[param_id]->Reshape(net_params[param_id]->shape()); - } } } else if (regularization_type == "L1") { - //LOG(INFO) << "Regularize stage: regularization_type == L1."; caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), temp_[param_id]->mutable_cpu_data()); - /* - caffe_axpy(net_params[param_id]->count(), - local_decay, - temp_[param_id]->cpu_data(), - net_params[param_id]->mutable_cpu_diff()); - */ - axpy_axpby_copy(net_params[param_id]->count(), local_decay, temp_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff(), local_rate, momentum, history_[param_id]->mutable_cpu_data()); @@ -513,18 +438,14 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { //No Regularize stage, only ComputeUpdateValue stage //ComputeUpdateValue stage if (prv_diff_condition_flag) { - //LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = true."; - if(net_params[param_id]->prv_data_count() != history_[param_id]->count()) - history_[param_id]->Reshape(net_params[param_id]->shape()); - caffe_cpu_axpby(net_params[param_id]->prv_data_count(), local_rate, + caffe_cpu_axpby(net_params[param_id]->prv_diff_count(), local_rate, net_params[param_id]->prv_diff(), momentum, history_[param_id]->mutable_cpu_data()); - caffe_copy(net_params[param_id]->prv_data_count(), + caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), net_params[param_id]->mutable_prv_diff()); } else { - //LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = false."; caffe_cpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); @@ -537,7 +458,6 @@ void SGDSolver::SGDFusion(int param_id, Dtype rate) { //Update stage (separate) net_params[param_id]->Update(); } - } #endif /* ENABLE_SGD_FUSION */ @@ -561,12 +481,10 @@ void SGDSolver::Normalize(int param_id) { if (net_params[param_id]->prv_diff() && (net_params[param_id]->prv_diff_count() == net_params[param_id]->count())) { - //LOG(INFO) << "Normalize stage: prv_diff_condition_flag = true."; - caffe_scal(net_params[param_id]->count(), accum_normalization, + caffe_scal(net_params[param_id]->prv_diff_count(), accum_normalization, net_params[param_id]->mutable_prv_diff()); } else { - //LOG(INFO) << "Normalize stage: prv_diff_condition_flag = false."; caffe_scal(net_params[param_id]->count(), accum_normalization, net_params[param_id]->mutable_cpu_diff()); } @@ -599,29 +517,25 @@ void SGDSolver::Regularize(int param_id) { case Caffe::CPU: { if (local_decay) { if (regularization_type == "L2") { - //LOG(INFO) << "Regularize stage: regularization_type == L2."; // add weight decay if (net_params[param_id]->prv_data() && (net_params[param_id]->prv_data_count() == net_params[param_id]->count())) { - //LOG(INFO) << "Regularize stage: prv_data_condition_flag = true."; CHECK_EQ(true, net_params[param_id]->get_prv_data_descriptor()->layout_compare( net_params[param_id]->get_prv_diff_descriptor())); - caffe_axpy(net_params[param_id]->count(), + caffe_axpy(net_params[param_id]->prv_data_count(), local_decay, net_params[param_id]->prv_data(), net_params[param_id]->mutable_prv_diff()); } else { - //LOG(INFO) << "Regularize stage: prv_data_condition_flag = false."; caffe_axpy(net_params[param_id]->count(), local_decay, net_params[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); } } else if (regularization_type == "L1") { - //LOG(INFO) << "Regularize stage: regularization_type == L1."; caffe_cpu_sign(net_params[param_id]->count(), net_params[param_id]->cpu_data(), temp_[param_id]->mutable_cpu_data()); @@ -692,8 +606,7 @@ void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { if (net_params[param_id]->prv_diff() && (net_params[param_id]->prv_diff_count() == net_params[param_id]->count())) { - //LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = true."; - caffe_cpu_axpby(net_params[param_id]->count(), local_rate, + caffe_cpu_axpby(net_params[param_id]->prv_diff_count(), local_rate, net_params[param_id]->prv_diff(), momentum, history_[param_id]->mutable_cpu_data()); @@ -701,7 +614,6 @@ void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { history_[param_id]->cpu_data(), net_params[param_id]->mutable_prv_diff()); } else { - //LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = false."; caffe_cpu_axpby(net_params[param_id]->count(), local_rate, net_params[param_id]->cpu_diff(), momentum, history_[param_id]->mutable_cpu_data()); @@ -709,6 +621,11 @@ void SGDSolver::ComputeUpdateValue(int param_id, Dtype rate) { caffe_copy(net_params[param_id]->count(), history_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff()); + + if (net_params[param_id]->prv_diff_count() + != net_params[param_id]->count()) { + net_params[param_id]->mutable_prv_diff(); + } } break; }