Skip to content

Commit

Permalink
Support padded layout
Browse files Browse the repository at this point in the history
  • Loading branch information
hshen14 committed Aug 25, 2017
1 parent bf824c4 commit c55cac3
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 97 deletions.
3 changes: 1 addition & 2 deletions src/caffe/mkldnn_memory.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,7 @@ void MKLDNNMemoryDescriptor<Dtype, is_diff>::convert_from_extprv(shared_ptr<prim
CHECK(aprimitive);
if(this->_reorder_extprv2prv_pd == NULL)
return;
if (this->_extprv_memory_pd->desc().data.format == this->_prv_memory_pd->desc().data.format &&
this->_extprv_memory_pd->desc().data.data_type == this->_prv_memory_pd->desc().data.data_type)
if (*this->_extprv_memory_pd == *this->_prv_memory_pd)
{
#ifdef DEBUG
LOG(INFO) << "The format and data_type of _extprv_memory_pd and _prv_memory_pd is same, no need do conversion.";
Expand Down
107 changes: 12 additions & 95 deletions src/caffe/solvers/sgd_solver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -354,38 +354,25 @@ void SGDSolver<Dtype>::SGDFusion(int param_id, Dtype rate) {
bool prv_diff_condition_flag = false;
if (net_params[param_id]->prv_diff()
&& (net_params[param_id]->prv_diff_count()
== net_params[param_id]->prv_data_count())) {
== net_params[param_id]->count())) {
prv_diff_condition_flag = true;
//LOG(INFO) << "Common condition judgement: prv_diff_condition_flag = true.";
}
else
{
//LOG(INFO) << "Common condition judgement: prv_diff_condition_flag = false.";
}
//#pragma endregion

//#pragma region 3. Normalize stage
if (skip_Normalize_stage_flag == false)
{
//LOG(INFO) << "Normalize stage: Normalize stage is not skipped.";

const Dtype accum_normalization = Dtype(1.) / this->param_.iter_size();

if (prv_diff_condition_flag) {
//LOG(INFO) << "Normalize stage: prv_diff_condition_flag = true.";
caffe_scal(net_params[param_id]->prv_data_count(), accum_normalization,
caffe_scal(net_params[param_id]->prv_diff_count(), accum_normalization,
net_params[param_id]->mutable_prv_diff());
}
else {
//LOG(INFO) << "Normalize stage: prv_diff_condition_flag = false.";
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
}
}
else
{
//LOG(INFO) << "Normalize stage: Normalize stage is skipped.";
}
//#pragma endregion

//For most common topologies from BVLC, all skipped the Normalize stage, and use L2 regularization
Expand All @@ -401,97 +388,35 @@ void SGDSolver<Dtype>::SGDFusion(int param_id, Dtype rate) {
//Regularize stage (Fused ComputeUpdateValue_stage in some situations)
if (local_decay) {
if (regularization_type == "L2") {
//LOG(INFO) << "Regularize stage: regularization_type == L2.";
// add weight decay
if (net_params[param_id]->prv_data()
&& (net_params[param_id]->prv_data_count()
== net_params[param_id]->count())) {
//LOG(INFO) << "Regularize stage: prv_data_condition_flag = true.";
CHECK_EQ(true,
net_params[param_id]->get_prv_data_descriptor()->layout_compare(
net_params[param_id]->get_prv_diff_descriptor()));
/*
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->prv_data(),
net_params[param_id]->mutable_prv_diff());
*/
if (prv_diff_condition_flag) {
//situation (1)
//LOG(INFO) << "Fused ComputeUpdateValue stage: prv_diff_condition_flag = true.";
/*
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->prv_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_prv_diff());
*/

if(net_params[param_id]->prv_data_count() != history_[param_id]->count())
history_[param_id]->Reshape(net_params[param_id]->shape());

axpy_axpby_copy_axpy(net_params[param_id]->prv_data_count(), local_decay,
net_params[param_id]->mutable_prv_data(), net_params[param_id]->mutable_prv_diff(),
local_rate, momentum, history_[param_id]->mutable_cpu_data(), Dtype(-1));

is_separate_ComputeUpdateValue_Update = false;
}
else
{
//Will not happen!
//LOG(INFO) << "Cannot Fused ComputeUpdateValue stage: prv_diff_condition_flag = false.";
}
} else {
//LOG(INFO) << "Regularize stage: prv_data_condition_flag = false.";
/*
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
*/
if (!prv_diff_condition_flag)
{
//situation (2)
//LOG(INFO) << "Fused ComputeUpdateValue stage: prv_diff_condition_flag = false.";
/*
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
*/

axpy_axpby_copy_axpy(net_params[param_id]->count(), local_decay,
net_params[param_id]->mutable_cpu_data(), net_params[param_id]->mutable_cpu_diff(),
local_rate, momentum, history_[param_id]->mutable_cpu_data(), Dtype(-1));

is_separate_ComputeUpdateValue_Update = false;
}
else
{
//Will not happen!
//LOG(INFO) << "Cannot Fused ComputeUpdateValue stage: prv_diff_condition_flag = true.";
if(net_params[param_id]->prv_data_count() != history_[param_id]->count())
history_[param_id]->Reshape(net_params[param_id]->shape());
}
}
} else if (regularization_type == "L1") {
//LOG(INFO) << "Regularize stage: regularization_type == L1.";
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
temp_[param_id]->mutable_cpu_data());

/*
caffe_axpy(net_params[param_id]->count(),
local_decay,
temp_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
*/

axpy_axpby_copy(net_params[param_id]->count(), local_decay,
temp_[param_id]->cpu_data(), net_params[param_id]->mutable_cpu_diff(),
local_rate, momentum, history_[param_id]->mutable_cpu_data());
Expand All @@ -513,18 +438,14 @@ void SGDSolver<Dtype>::SGDFusion(int param_id, Dtype rate) {
//No Regularize stage, only ComputeUpdateValue stage
//ComputeUpdateValue stage
if (prv_diff_condition_flag) {
//LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = true.";
if(net_params[param_id]->prv_data_count() != history_[param_id]->count())
history_[param_id]->Reshape(net_params[param_id]->shape());
caffe_cpu_axpby(net_params[param_id]->prv_data_count(), local_rate,
caffe_cpu_axpby(net_params[param_id]->prv_diff_count(), local_rate,
net_params[param_id]->prv_diff(), momentum,
history_[param_id]->mutable_cpu_data());

caffe_copy(net_params[param_id]->prv_data_count(),
caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_prv_diff());
} else {
//LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = false.";
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());
Expand All @@ -537,7 +458,6 @@ void SGDSolver<Dtype>::SGDFusion(int param_id, Dtype rate) {
//Update stage (separate)
net_params[param_id]->Update();
}

}
#endif /* ENABLE_SGD_FUSION */

Expand All @@ -561,12 +481,10 @@ void SGDSolver<Dtype>::Normalize(int param_id) {
if (net_params[param_id]->prv_diff()
&& (net_params[param_id]->prv_diff_count()
== net_params[param_id]->count())) {
//LOG(INFO) << "Normalize stage: prv_diff_condition_flag = true.";
caffe_scal(net_params[param_id]->count(), accum_normalization,
caffe_scal(net_params[param_id]->prv_diff_count(), accum_normalization,
net_params[param_id]->mutable_prv_diff());
}
else {
//LOG(INFO) << "Normalize stage: prv_diff_condition_flag = false.";
caffe_scal(net_params[param_id]->count(), accum_normalization,
net_params[param_id]->mutable_cpu_diff());
}
Expand Down Expand Up @@ -599,29 +517,25 @@ void SGDSolver<Dtype>::Regularize(int param_id) {
case Caffe::CPU: {
if (local_decay) {
if (regularization_type == "L2") {
//LOG(INFO) << "Regularize stage: regularization_type == L2.";
// add weight decay
if (net_params[param_id]->prv_data()
&& (net_params[param_id]->prv_data_count()
== net_params[param_id]->count())) {
//LOG(INFO) << "Regularize stage: prv_data_condition_flag = true.";
CHECK_EQ(true,
net_params[param_id]->get_prv_data_descriptor()->layout_compare(
net_params[param_id]->get_prv_diff_descriptor()));

caffe_axpy(net_params[param_id]->count(),
caffe_axpy(net_params[param_id]->prv_data_count(),
local_decay,
net_params[param_id]->prv_data(),
net_params[param_id]->mutable_prv_diff());
} else {
//LOG(INFO) << "Regularize stage: prv_data_condition_flag = false.";
caffe_axpy(net_params[param_id]->count(),
local_decay,
net_params[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());
}
} else if (regularization_type == "L1") {
//LOG(INFO) << "Regularize stage: regularization_type == L1.";
caffe_cpu_sign(net_params[param_id]->count(),
net_params[param_id]->cpu_data(),
temp_[param_id]->mutable_cpu_data());
Expand Down Expand Up @@ -692,23 +606,26 @@ void SGDSolver<Dtype>::ComputeUpdateValue(int param_id, Dtype rate) {
if (net_params[param_id]->prv_diff()
&& (net_params[param_id]->prv_diff_count()
== net_params[param_id]->count())) {
//LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = true.";
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
caffe_cpu_axpby(net_params[param_id]->prv_diff_count(), local_rate,
net_params[param_id]->prv_diff(), momentum,
history_[param_id]->mutable_cpu_data());

caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_prv_diff());
} else {
//LOG(INFO) << "ComputeUpdateValue stage: prv_diff_condition_flag = false.";
caffe_cpu_axpby(net_params[param_id]->count(), local_rate,
net_params[param_id]->cpu_diff(), momentum,
history_[param_id]->mutable_cpu_data());

caffe_copy(net_params[param_id]->count(),
history_[param_id]->cpu_data(),
net_params[param_id]->mutable_cpu_diff());

if (net_params[param_id]->prv_diff_count()
!= net_params[param_id]->count()) {
net_params[param_id]->mutable_prv_diff();
}
}
break;
}
Expand Down

0 comments on commit c55cac3

Please sign in to comment.