From 8daccc9ea7dbabec034882575b3738cf5c4c1dcc Mon Sep 17 00:00:00 2001 From: ceci3 Date: Fri, 25 Sep 2020 16:25:49 +0800 Subject: [PATCH] Fix batch norm double grad compute (#27549) * fix bn double grad, test=develop * update, test=develop --- paddle/fluid/operators/batch_norm_op.cc | 55 ++++++++------ paddle/fluid/operators/instance_norm_op.cc | 6 +- paddle/fluid/operators/norm_utils.cu.h | 75 ++++++++++++++----- .../tests/unittests/test_norm_nn_grad.py | 36 +++++++++ 4 files changed, 131 insertions(+), 41 deletions(-) diff --git a/paddle/fluid/operators/batch_norm_op.cc b/paddle/fluid/operators/batch_norm_op.cc index dcfe8bb1bb48a5..7a88403aa9daa7 100644 --- a/paddle/fluid/operators/batch_norm_op.cc +++ b/paddle/fluid/operators/batch_norm_op.cc @@ -839,6 +839,7 @@ void BatchNormDoubleGradMaker::Apply(GradOpPtr op) const { op->SetInput("SavedMean", this->Input("SavedMean")); op->SetInput("SavedVariance", this->Input("SavedVariance")); if (BOOST_GET_CONST(bool, this->GetAttr("use_global_stats"))) { + op->SetInput("Mean", this->Input("Mean")); op->SetInput("Variance", this->Input("Variance")); } op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X"))); @@ -868,14 +869,19 @@ void BatchNormDoubleGradOp::InferShape( "BatchNormDoubleGrad"); } - OP_INOUT_CHECK(ctx->HasInput("DDX"), "Input", "DDX", "BatchNormDoubleGrad"); OP_INOUT_CHECK(ctx->HasInput("DY"), "Input", "DY", "BatchNormDoubleGrad"); // check output OP_INOUT_CHECK(ctx->HasOutput("DX"), "Output", "DX", "BatchNormDoubleGrad"); const auto x_dims = ctx->GetInputDim("X"); - const int C = x_dims[1]; + const DataLayout data_layout = framework::StringToDataLayout( + ctx->Attrs().Get("data_layout")); + const int C = + ((this->IsMKLDNNType() == true) || (data_layout == DataLayout::kNCHW) + ? x_dims[1] + : x_dims[x_dims.size() - 1]); + if (ctx->HasOutput("DX")) { ctx->SetOutputDim("DX", x_dims); } @@ -957,7 +963,9 @@ class BatchNormDoubleGradKernel Tensor inv_var_tensor; if (use_global_stats) { + const auto *running_mean = ctx.Input("Mean"); const auto *running_variance = ctx.Input("Variance"); + mean_data = running_mean->data(); inv_var_tensor.Resize({C}); T *running_inv_var_data = inv_var_tensor.mutable_data(ctx.GetPlace()); @@ -1077,12 +1085,12 @@ class BatchNormDoubleGradKernel // (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW * // np.sum(dy, // axis=(n,h,w)) * (x - mean) * - // (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var - + // (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var - // inv_var // * // np.mean(dy, axis=(n,h,w)) - // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), - // axis=(n,h,w)))) + // axis=(n,h,w))) if (ddX) { dx_arr += @@ -1176,7 +1184,8 @@ class BatchNormDoubleGradKernel C, sample_size); ddy_arr.setZero(); if (use_global_stats) { - // math: ddy = r * ddx * inv_var + // math: ddy = r * ddx * inv_var + ddbias + + // ddscale * (x - mean) * inv_var if (ddX) { ddy_arr = scale_tile_data * ddx_arr * inv_var_tile_data; } @@ -1196,25 +1205,29 @@ class BatchNormDoubleGradKernel .replicate(1, sample_size) / sample_size); } - if (ddScale && ddBias) { - ConstEigenVectorArrayMap ddscale_arr(ddScale->data(), C); - Tensor ddscale_tile; - ddscale_tile.Resize({C, sample_size}); - EigenArrayMap ddscale_tile_data( - ddscale_tile.mutable_data(ctx.GetPlace()), C, sample_size); - ddscale_tile_data = ddscale_arr.replicate(1, sample_size); + } + if (ddScale) { + ConstEigenVectorArrayMap ddscale_arr(ddScale->data(), C); + Tensor ddscale_tile; + ddscale_tile.Resize({C, sample_size}); + EigenArrayMap ddscale_tile_data( + ddscale_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddscale_tile_data = ddscale_arr.replicate(1, sample_size); + + ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data; + } - ConstEigenVectorArrayMap ddbias_arr(ddBias->data(), C); - Tensor ddbias_tile; - ddbias_tile.Resize({C, sample_size}); - EigenArrayMap ddbias_tile_data( - ddbias_tile.mutable_data(ctx.GetPlace()), C, sample_size); - ddbias_tile_data = ddbias_arr.replicate(1, sample_size); + if (ddBias) { + ConstEigenVectorArrayMap ddbias_arr(ddBias->data(), C); + Tensor ddbias_tile; + ddbias_tile.Resize({C, sample_size}); + EigenArrayMap ddbias_tile_data( + ddbias_tile.mutable_data(ctx.GetPlace()), C, sample_size); + ddbias_tile_data = ddbias_arr.replicate(1, sample_size); - ddy_arr += x_sub_mean_mul_invstd_arr * ddscale_tile_data; - ddy_arr += ddbias_tile_data; - } + ddy_arr += ddbias_tile_data; } + if (data_layout == DataLayout::kNCHW) { VLOG(3) << "Transform batchnorm output from NHWC to NCHW"; TransToChannelFirst( diff --git a/paddle/fluid/operators/instance_norm_op.cc b/paddle/fluid/operators/instance_norm_op.cc index a5b270c1dfef14..03279a9b2c15b8 100644 --- a/paddle/fluid/operators/instance_norm_op.cc +++ b/paddle/fluid/operators/instance_norm_op.cc @@ -520,11 +520,11 @@ class InstanceNormDoubleGradKernel // (np.mean(dy, axis=(h,w)) - dy) + inv_var.pow(3) / HxW * // np.sum(dy, // axis=(h,w)) * (x - mean) * - // (np.mean(ddx, axis=(h,w)) - ddx) + ddr * (dy * inv_var - inv_var - // * + // (np.mean(ddx, axis=(h,w)) - ddx)) + ddr * (dy * inv_var - + // inv_var * // np.mean(dy, axis=(h,w)) - // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), - // axis=(h,w)))) + // axis=(h,w))) Tensor x_sub_mean_mul_invstd; x_sub_mean_mul_invstd.Resize({sample_size, NxC}); diff --git a/paddle/fluid/operators/norm_utils.cu.h b/paddle/fluid/operators/norm_utils.cu.h index 07333f1ae11c38..02dcb4045f4cde 100644 --- a/paddle/fluid/operators/norm_utils.cu.h +++ b/paddle/fluid/operators/norm_utils.cu.h @@ -40,12 +40,12 @@ using DataLayout = framework::DataLayout; // (np.mean(dy, axis=(n,h,w)) - dy) + inv_var.pow(3) / NxHxW * // np.sum(dy, // axis=(n,h,w)) * (x - mean) * -// (np.mean(ddx, axis=(n,h,w)) - ddx) + ddr * (dy * inv_var - +// (np.mean(ddx, axis=(n,h,w)) - ddx)) + ddr * (dy * inv_var - // inv_var // * // np.mean(dy, axis=(n,h,w)) - // inv_var.pow(3) * (x - mean) * np.mean(dy * (x - mean), -// axis=(n,h,w)))) +// axis=(n,h,w))) template __global__ void DoubleGradComputeDX(const T *x, const T *mean, @@ -138,7 +138,7 @@ __global__ void DoubleGradComputeDX(const T *x, const T *mean, ? (j / sample_size * C + i) * sample_size + j % sample_size : j * outer_size + i; dx[index] += (dy[index] * var_val - dy_sum_val / inner_size * var_val - - (x[index] - mean_val) * var_val * + (x[index] - mean_val) * var_val * var_val * dy_mul_x_sub_mean_sum_val * var_val / inner_size) * ddscale[i]; } @@ -326,19 +326,57 @@ __global__ void DoubleGradComputeDScaleWithGlobal( } // math: dx = ddscale * dy * inv_var -// math: ddy = scale * ddx * inv_var template -__global__ void DoubleGradComputeDataWithGlobal( - const T *dy, const T *scale, const T *variance, const double epsilon, - const int C, const int sample_size, const int num, T *dx) { +__global__ void DoubleGradComputeDXWithGlobal(const T *dy, const T *ddscale, + const T *variance, + const double epsilon, const int C, + const int sample_size, + const int num, T *dx) { int gid = blockIdx.x * blockDim.x + threadIdx.x; int stride = blockDim.x * gridDim.x; - if (scale != nullptr) { + if (ddscale != nullptr) { for (int i = gid; i < num; i += stride) { const int c = layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C; T inv_var = 1.0 / sqrt(variance[c] + epsilon); - dx[i] = dy[i] * scale[c] * inv_var; + dx[i] = dy[i] * ddscale[c] * inv_var; + } + } +} + +// math: ddy = scale * ddx * inv_var + ddbias + +// ddscale * (x - mean) * inv_var +template +__global__ void DoubleGradComputeDDYWithGlobal( + const T *ddx, const T *scale, const T *mean, const T *variance, const T *x, + const T *ddbias, const T *ddscale, const double epsilon, const int C, + const int sample_size, const int num, T *ddy) { + int gid = blockIdx.x * blockDim.x + threadIdx.x; + int stride = blockDim.x * gridDim.x; + + if (ddx != nullptr) { + for (int i = gid; i < num; i += stride) { + const int c = + layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C; + T inv_var = 1.0 / sqrt(variance[c] + epsilon); + ddy[i] += ddx[i] * scale[c] * inv_var; + } + } + __syncthreads(); + if (ddscale != nullptr) { + for (int i = gid; i < num; i += stride) { + const int c = + layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C; + T inv_var = 1.0 / sqrt(variance[c] + epsilon); + ddy[i] += (x[i] - mean[c]) * inv_var * ddscale[c]; + } + } + __syncthreads(); + if (ddbias != nullptr) { + for (int i = gid; i < num; i += stride) { + const int c = + layout == framework::DataLayout::kNCHW ? i / sample_size % C : i % C; + ddy[i] += ddbias[c]; } } } @@ -383,8 +421,11 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx, const T *mean_data, *variance_data; if (use_global_stats) { + const auto *running_mean = ctx.Input("Mean"); const auto *running_var = ctx.Input("Variance"); + const auto *running_mean_data = running_mean->template data(); const auto *running_var_data = running_var->template data(); + mean_data = running_mean_data; variance_data = running_var_data; } else { const T *smean_data = Saved_mean->data(); @@ -398,12 +439,12 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx, set_constant(dev_ctx, dX, static_cast(0)); if (use_global_stats) { if (data_layout == DataLayout::kNHWC) { - DoubleGradComputeDataWithGlobal< + DoubleGradComputeDXWithGlobal< T, DataLayout::kNHWC><<>>( dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num, dx_data); } else { - DoubleGradComputeDataWithGlobal< + DoubleGradComputeDXWithGlobal< T, DataLayout::kNCHW><<>>( dy_data, ddscale_data, variance_data, epsilon, C, sample_size, num, dx_data); @@ -456,15 +497,15 @@ void NormDoubleGradFunctor(const framework::ExecutionContext &ctx, set_constant(dev_ctx, ddY, static_cast(0)); if (use_global_stats) { if (data_layout == DataLayout::kNHWC) { - DoubleGradComputeDataWithGlobal< + DoubleGradComputeDDYWithGlobal< T, DataLayout::kNHWC><<>>( - ddx_data, scale_data, variance_data, epsilon, C, sample_size, num, - ddy_data); + ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data, + ddscale_data, epsilon, C, sample_size, num, ddy_data); } else { - DoubleGradComputeDataWithGlobal< + DoubleGradComputeDDYWithGlobal< T, DataLayout::kNCHW><<>>( - ddx_data, scale_data, variance_data, epsilon, C, sample_size, num, - ddy_data); + ddx_data, scale_data, mean_data, variance_data, x_data, ddbias_data, + ddscale_data, epsilon, C, sample_size, num, ddy_data); } } else { if (data_layout == DataLayout::kNHWC) { diff --git a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py index a89b9fde7f92de..cb4bd16ce219f8 100644 --- a/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py +++ b/python/paddle/fluid/tests/unittests/test_norm_nn_grad.py @@ -130,5 +130,41 @@ def init_test(self): self.shape = [2, 2, 3, 4, 5] +class TestBatchNormDoubleGradCheckCase5(TestBatchNormDoubleGradCheck): + @prog_scope() + def func(self, place): + prog = fluid.Program() + with fluid.program_guard(prog): + np.random.seed() + dtype = "float32" + eps = 0.005 + atol = 2e-4 + chn = self.shape[1] if self.data_layout == 'NCHW' else self.shape[ + -1] + x = layers.create_parameter(dtype=dtype, shape=self.shape, name='x') + z = fluid.layers.batch_norm( + input=x, + data_layout=self.data_layout, + use_global_stats=self.use_global_stats) + x_arr = np.random.uniform(-1, 1, self.shape).astype(dtype) + w, b = prog.global_block().all_parameters()[1:3] + w_arr = np.ones(chn).astype(dtype) + b_arr = np.zeros(chn).astype(dtype) + gradient_checker.double_grad_check( + [x, w, b], + z, + x_init=[x_arr, w_arr, b_arr], + atol=atol, + place=place, + eps=eps) + + +class TestBatchNormDoubleGradCheckCase6(TestBatchNormDoubleGradCheckCase5): + def init_test(self): + self.data_layout = 'NCHW' + self.use_global_stats = True + self.shape = [2, 3, 4, 5] + + if __name__ == "__main__": unittest.main()