support scaleshift accum with stats batch size > 1
Change-Id: I3b1a16dae1a6a2965b43ce61109d0a58b70e9093
Gong, Jiong committed Aug 25, 2017
1 parent c83332f commit bf824c4
Showing 2 changed files with 43 additions and 25 deletions.
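The new behavior is driven by the stats_batch_size field of BatchNormParameter, which the code below reads via param.stats_batch_size(). A minimal prototxt sketch of how a layer might opt in; the field name is taken from the accessor used in this commit, while the surrounding layer definition (layer and blob names) is hypothetical:

    layer {
      name: "bn1"                 # hypothetical
      type: "BatchNorm"
      bottom: "conv1"
      top: "bn1"
      batch_norm_param {
        # Compute batch-norm statistics over groups of 8 images instead of the
        # whole minibatch; the minibatch size must be divisible by this value.
        stats_batch_size: 8
      }
    }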
5 changes: 3 additions & 2 deletions include/caffe/layers/mkldnn_layers.hpp
@@ -71,7 +71,6 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
         , scaleshift_memory(), bwd_scaleshift_diff_memory()
         , output_memory(), bwd_bottom_diff_memory()
         , input_primitive(), bwd_top_diff_primitive()
-        , scaleshift_combination()
     {
         PERFORMANCE_EVENT_ID_RESET(perf_id_fw_);
         PERFORMANCE_EVENT_ID_RESET(perf_id_bw_);
@@ -100,6 +99,7 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
     void InitBatchNormBwdPrimitive(int stats_batch_idx);
     template <bool diff> shared_ptr<memory> GetStatsBatchMemory(
       shared_ptr<MKLDNNMemoryDescriptor<Dtype, diff> > mkldnn_data, int idx);
+    void InitStatsBatchVars(int batch_size);
     shared_ptr<MKLDNNData<Dtype> > fwd_top_data, fwd_bottom_data;
     shared_ptr<MKLDNNDiff<Dtype> > bwd_top_diff, bwd_bottom_diff;
     shared_ptr<batch_normalization_forward::primitive_desc> BatchNormFwd_pd;
@@ -119,7 +119,8 @@ class MKLDNNBatchNormLayer : public MKLDNNLayer<Dtype>, public Layer<Dtype> {
     bool use_weight_bias_, bias_term_, use_global_stats_;
     int num_stats_batches_;
     int stats_batch_size_;
-    shared_ptr<Blob<Dtype>> scaleshift_combination;
+    shared_ptr<Blob<Dtype> > scaleshift_blob_;
+    shared_ptr<Blob<Dtype> > scaleshift_acc_;

     PERFORMANCE_EVENT_ID_DECL(perf_id_fw_);
     PERFORMANCE_EVENT_ID_DECL(perf_id_bw_);
63 changes: 40 additions & 23 deletions src/caffe/layers/mkldnn_batch_norm_layer.cpp
@@ -44,6 +44,18 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

 namespace caffe {

+template <typename Dtype>
+void MKLDNNBatchNormLayer<Dtype>::InitStatsBatchVars(int batch_size) {
+  num_stats_batches_ = 1;
+  stats_batch_size_ = batch_size;
+  BatchNormParameter param = this->layer_param_.batch_norm_param();
+  if (!use_global_stats_ && param.stats_batch_size() > 0) {
+    CHECK_EQ(batch_size % param.stats_batch_size(), 0);
+    num_stats_batches_ = batch_size / param.stats_batch_size();
+    stats_batch_size_ = param.stats_batch_size();
+  }
+}
+
 template <typename Dtype>
 void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
                                             ,const vector<Blob<Dtype>*>& top)
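InitStatsBatchVars partitions the minibatch into equal-sized stats batches. A self-contained sketch of the same arithmetic with hypothetical numbers (plain C++, no Caffe dependencies; the free function mirrors the member above):

    #include <cassert>

    // Mirrors InitStatsBatchVars: split batch_size into equal stats batches.
    void init_stats_batch_vars(int batch_size, int stats_batch_size_param,
                               bool use_global_stats,
                               int* num_stats_batches, int* stats_batch_size) {
      *num_stats_batches = 1;
      *stats_batch_size = batch_size;
      if (!use_global_stats && stats_batch_size_param > 0) {
        // CHECK_EQ in Caffe: the minibatch must divide evenly.
        assert(batch_size % stats_batch_size_param == 0);
        *num_stats_batches = batch_size / stats_batch_size_param;
        *stats_batch_size = stats_batch_size_param;
      }
    }

    // e.g. batch_size = 32, stats_batch_size = 8
    //   -> num_stats_batches_ = 4, stats_batch_size_ = 8:
    // statistics are computed over 4 independent groups of 8 images each.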
@@ -65,6 +77,8 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
   if (this->layer_param_.batch_norm_param().has_use_global_stats())
     use_global_stats_ = this->layer_param_.batch_norm_param().use_global_stats();

+  InitStatsBatchVars(num_);
+
   this->blobs_.resize(3 + (use_weight_bias_ ? 1:0) + (use_weight_bias_ && bias_term_ ? 1:0));

   vector<int> sz;
@@ -81,15 +95,18 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
   //IntelCaffe treats scale and shift as different blobs, so the current MKL-DNN integration makes additional copies from the Caffe buffer to the MKL-DNN buffer on the fwd pass and from MKL-DNN back to Caffe on the bwd pass.
   //Optimization: use a temp blob to combine the scale and shift together and avoid the additional copies.
   // Initialize scale and shift combination blob
-  vector<int> scaleshift_combination_shape(1);
-  scaleshift_combination_shape[0] = 2*channels_;
-  this->scaleshift_combination.reset(new Blob<Dtype>(scaleshift_combination_shape));
-  //Should initialize the scaleshift_combine buffer to 0, because when bias_term_ == false, need to pass zero bias to MKLDNN
-  caffe_set(scaleshift_combination_shape[0], static_cast<Dtype>(0),
-            scaleshift_combination->mutable_cpu_data());
-  //Not so necessary, because the diff will initialize to 0 automatically
-  caffe_set(scaleshift_combination_shape[0], static_cast<Dtype>(0),
-            scaleshift_combination->mutable_cpu_diff());
+  vector<int> scaleshift_blob_shape(1);
+  scaleshift_blob_shape[0] = 2*channels_;
+  scaleshift_blob_.reset(new Blob<Dtype>(scaleshift_blob_shape));
+  //Should initialize the scaleshift_blob_ buffer to 0, because when bias_term_ == false we need to pass a zero bias to MKLDNN
+  caffe_set(scaleshift_blob_shape[0], static_cast<Dtype>(0),
+            scaleshift_blob_->mutable_cpu_data());
+  shared_ptr<Blob<Dtype> > scaleshift_diff_blob = scaleshift_blob_;
+  scaleshift_acc_ = scaleshift_blob_;
+  if (num_stats_batches_ > 1) {
+    this->scaleshift_acc_.reset(new Blob<Dtype>(scaleshift_blob_shape));
+    scaleshift_diff_blob = scaleshift_acc_;
+  }

   if (use_weight_bias_) {
     // Initialize scale and shift
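The key move in this hunk: scaleshift_acc_ aliases scaleshift_blob_ when there is a single stats batch, and becomes a separate buffer when num_stats_batches_ > 1, so the diff wired into the learnable blobs (via scaleshift_diff_blob) can accumulate across stats batches while MKL-DNN keeps writing per-batch gradients into scaleshift_blob_. A minimal sketch of that aliasing pattern, using std::shared_ptr and hypothetical sizes rather than Caffe's Blob:

    #include <memory>
    #include <vector>

    int main() {
      typedef std::vector<float> Buf;
      std::shared_ptr<Buf> blob(new Buf(2 * 64, 0.f));  // per-batch scale/shift diff
      std::shared_ptr<Buf> acc = blob;  // 1 stats batch: accumulator aliases blob
      int num_stats_batches = 4;        // hypothetical
      if (num_stats_batches > 1) {
        acc.reset(new Buf(blob->size(), 0.f));  // separate accumulator buffer
      }
      // The learnable param's diff pointer is wired to acc, so the solver sees
      // accumulated gradients; MKL-DNN keeps writing into blob each iteration.
      return 0;
    }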
@@ -98,8 +115,8 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom
     VLOG(1) << "MKLDNNBatchNormLayer<Dtype>::LayerSetUp: channels_ = " << channels_;

     this->blobs_[3].reset(new Blob<Dtype>(scaleshift_shape));
-    this->blobs_[3]->set_cpu_data(scaleshift_combination->mutable_cpu_data());
-    this->blobs_[3]->set_cpu_diff(scaleshift_combination->mutable_cpu_diff());
+    this->blobs_[3]->set_cpu_data(scaleshift_blob_->mutable_cpu_data());
+    this->blobs_[3]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff());
     FillerParameter filler_param(this->layer_param_.batch_norm_param().filler());
     if (!this->layer_param_.batch_norm_param().has_filler()) {
       filler_param.set_type("constant");
@@ -111,8 +128,8 @@ void MKLDNNBatchNormLayer<Dtype>::LayerSetUp(const vector<Blob<Dtype>*>& bottom

     if (bias_term_) {
       this->blobs_[4].reset(new Blob<Dtype>(scaleshift_shape));
-      this->blobs_[4]->set_cpu_data(scaleshift_combination->mutable_cpu_data() + scaleshift_combination->offset(channels_));
-      this->blobs_[4]->set_cpu_diff(scaleshift_combination->mutable_cpu_diff() + scaleshift_combination->offset(channels_));
+      this->blobs_[4]->set_cpu_data(scaleshift_blob_->mutable_cpu_data() + scaleshift_blob_->offset(channels_));
+      this->blobs_[4]->set_cpu_diff(scaleshift_diff_blob->mutable_cpu_diff() + scaleshift_blob_->offset(channels_));
       FillerParameter bias_filler_param(this->layer_param_.batch_norm_param().bias_filler());
       if (!this->layer_param_.batch_norm_param().has_bias_filler()) {
         bias_filler_param.set_type("constant");
@@ -149,14 +166,7 @@ void MKLDNNBatchNormLayer<Dtype>::Reshape(const vector<Blob<Dtype>*>& bottom
     this->num_ = bottom[0]->num();
     this->channels_ = bottom[0]->channels();

-    num_stats_batches_ = 1;
-    stats_batch_size_ = bottom[0]->shape(0);
-    BatchNormParameter param = this->layer_param_.batch_norm_param();
-    if (!use_global_stats_ && param.stats_batch_size() > 0) {
-      CHECK_EQ(bottom[0]->shape(0) % param.stats_batch_size(), 0);
-      num_stats_batches_ = bottom[0]->shape(0) / param.stats_batch_size();
-      stats_batch_size_ = param.stats_batch_size();
-    }
+    InitStatsBatchVars(this->num_);

     //Fix: should reshape the top blob with the real size of bottom blob
     //top[0]->Reshape(this->num_, this->channels_, this->height_, this->width_);
@@ -229,7 +239,7 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNorm(const vector<Blob<Dtype>*>& bottom

     // ---- Create memory ---------------------
     if (use_weight_bias_) {
-        scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_combination->mutable_cpu_data()));
+        scaleshift_memory.reset(new memory(BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_blob_->mutable_cpu_data()));
     }

     // --- init primitive and prv_memory descriptors ----------------------
@@ -463,7 +473,7 @@ void MKLDNNBatchNormLayer<Dtype>::InitBatchNormBwd(

     if (use_weight_bias_) {
         bwd_scaleshift_diff_memory.reset(new memory(
-            BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_combination->mutable_cpu_diff()));
+            BatchNormFwd_pd->weights_primitive_desc(), this->scaleshift_blob_->mutable_cpu_diff()));
     }

     // --- init primitive and prv_memory descriptors ----------------------
@@ -561,6 +571,13 @@ void MKLDNNBatchNormLayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
     }
 #endif
     PERFORMANCE_MEASUREMENT_END_ID(perf_id_bw_);
+    if (num_stats_batches_ > 1) {
+      CHECK(scaleshift_blob_ != scaleshift_acc_);
+      CHECK(scaleshift_blob_->count() == scaleshift_acc_->count());
+      caffe_cpu_axpby(scaleshift_acc_->count(), Dtype(1),
+                      scaleshift_blob_->mutable_cpu_diff(),
+                      Dtype(1), scaleshift_acc_->mutable_cpu_diff());
+    }
   }
 }

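caffe_cpu_axpby computes y = alpha*x + beta*y; with alpha = beta = 1 it adds the per-stats-batch scale/shift gradient that MKL-DNN wrote into scaleshift_blob_'s diff onto the accumulator that blobs_[3]/blobs_[4] expose to the solver. A standalone sketch of those semantics (Caffe's version dispatches to cblas_saxpby/cblas_daxpby):

    #include <cstddef>

    // y = alpha * x + beta * y: the semantics of caffe_cpu_axpby.
    void axpby(std::size_t n, float alpha, const float* x, float beta, float* y) {
      for (std::size_t i = 0; i < n; ++i)
        y[i] = alpha * x[i] + beta * y[i];
    }

    // Per backward pass: axpby(count, 1.f, scaleshift_blob_diff,
    //                          1.f, scaleshift_acc_diff)
    // runs once per stats batch, so the accumulator ends up holding the
    // sum of the per-stats-batch scale/shift gradients.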
