Skip to content

Commit

Permalink
Make grid support stopping graients. (PaddlePaddle#27630)
Browse files Browse the repository at this point in the history
  • Loading branch information
wanghaoshuang authored Sep 28, 2020
1 parent 074a71b commit 9cc5603
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 55 deletions.
2 changes: 0 additions & 2 deletions paddle/fluid/operators/grid_sampler_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,6 @@ class GridSampleOpGrad : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("X")), "Output",
framework::GradVarName("X"), "grid_sampler");
OP_INOUT_CHECK(ctx->HasOutput(framework::GradVarName("Grid")), "Output",
framework::GradVarName("Grid"), "grid_sampler");
auto input_dims = ctx->GetInputDim("X");
auto grid_dims = ctx->GetInputDim("Grid");
if (ctx->HasOutput(framework::GradVarName("X"))) {
Expand Down
34 changes: 21 additions & 13 deletions paddle/fluid/operators/grid_sampler_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -397,9 +397,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
}
}

T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = gix_mult * gix;
gGrid_ptr_NHW[1] = giy_mult * giy;
}
} else if (mode == Mode::nearest) {
int ix_nearest = static_cast<int>(::round(ix));
int iy_nearest = static_cast<int>(::round(iy));
Expand All @@ -412,9 +414,11 @@ __global__ void grid_sampler_cuda_backward_kernel(
in_w, grad_output[gOut_offset]);
}

T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
if (grad_grid != nullptr) {
T* gGrid_ptr_NHW = grad_grid + index * grid_sW;
gGrid_ptr_NHW[0] = static_cast<T>(0);
gGrid_ptr_NHW[1] = static_cast<T>(0);
}
}
}
}
Expand Down Expand Up @@ -460,20 +464,24 @@ class GridSampleGradOpCUDAKernel : public framework::OpKernel<T> {
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
input_grad, static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));

T* grid_grad_data = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad_data = grid_grad->mutable_data<T>(ctx.GetPlace());
math::SetConstant<paddle::platform::CUDADeviceContext, T>()(
ctx.template device_context<paddle::platform::CUDADeviceContext>(),
grid_grad, static_cast<T>(0));
}

int count = static_cast<int>(n * out_h * out_w);
auto cu_stream = dev_ctx.stream();
int block = 512;
int grid_size = (count + block - 1) / block;
grid_sampler_cuda_backward_kernel<T><<<block, grid_size, 0, cu_stream>>>(
count, output_grad->data<T>(), input->data<T>(), grid->data<T>(), n, c,
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad->data<T>(),
mode, padding_mode, align_corners);
out_h, out_w, in_h, in_w, input_grad->data<T>(), grid_grad_data, mode,
padding_mode, align_corners);
}
};

Expand Down
87 changes: 47 additions & 40 deletions paddle/fluid/operators/grid_sampler_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -450,45 +450,47 @@ static void gatherBilinearGrad(const platform::CPUDeviceContext& ctx,

auto output_grad_t = EigenTensor<T, 4>::From(output_grad);

Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto grid_grad_x_t =
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
if (grid_grad != nullptr) {
Tensor grid_grad_x, grid_grad_y;
grid_grad_x.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
grid_grad_y.mutable_data<T>({n, out_h, out_w}, ctx.GetPlace());
auto grid_grad_x_t =
EigenTensor<T, 3>::From(grid_grad_x).setConstant(static_cast<T>(0.0));
auto grid_grad_y_t =
EigenTensor<T, 3>::From(grid_grad_y).setConstant(static_cast<T>(0.0));
for (int i = 0; i < n; i++) {
for (int j = 0; j < c; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
grid_grad_x_t(i, k, l) +=
((v_en_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_s_t(i, k, l) +
(v_es_t(i, j, k, l) - v_ws_t(i, j, k, l)) * d_n_t(i, k, l)) *
output_grad_t(i, j, k, l);
grid_grad_y_t(i, k, l) +=
((v_ws_t(i, j, k, l) - v_wn_t(i, j, k, l)) * d_e_t(i, k, l) +
(v_es_t(i, j, k, l) - v_en_t(i, j, k, l)) * d_w_t(i, k, l)) *
output_grad_t(i, j, k, l);
}
}
}
}
}

// const T x_max = static_cast<T>(in_w - 1);
// const T y_max = static_cast<T>(in_h - 1);

auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;

// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
// const T x_max = static_cast<T>(in_w - 1);
// const T y_max = static_cast<T>(in_h - 1);

auto grid_x_scale_t = EigenTensor<T, 3>::From(*grid_x_scale);
auto grid_y_scale_t = EigenTensor<T, 3>::From(*grid_y_scale);
grid_grad_x_t = grid_grad_x_t * grid_x_scale_t;
grid_grad_y_t = grid_grad_y_t * grid_y_scale_t;

// gather grid_grad [x, y] in 3rd Dim
T* grid_grad_data = grid_grad->data<T>();
T* grid_grad_x_data = grid_grad_x.data<T>();
T* grid_grad_y_data = grid_grad_y.data<T>();
for (int i = 0; i < n * out_h * out_w; i++) {
grid_grad_data[2 * i] = grid_grad_x_data[i];
grid_grad_data[2 * i + 1] = grid_grad_y_data[i];
}
}
}

Expand Down Expand Up @@ -558,11 +560,16 @@ class GridSampleGradOpKernel : public framework::OpKernel<T> {
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), input_grad,
static_cast<T>(0));
auto* grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));

Tensor* grid_grad = nullptr;
if (ctx.HasOutput(framework::GradVarName("Grid"))) {
grid_grad = ctx.Output<Tensor>(framework::GradVarName("Grid"));
grid_grad->mutable_data<T>({n, out_h, out_w, 2}, ctx.GetPlace());
math::SetConstant<DeviceContext, T>()(
ctx.template device_context<DeviceContext>(), grid_grad,
static_cast<T>(0));
}

Tensor grid_x, grid_y;
Tensor grid_x_scale, grid_y_scale;
calcGridLocationsWithGrad<T>(
Expand Down

0 comments on commit 9cc5603

Please sign in to comment.