Skip to content

Commit 3044a62

Browse files
authored
fix the precise roi pool op test=develop (PaddlePaddle#20126)
* fix the precise roi pool op test=develop; add the ROIs backward implementation and fix the output-channel computation
1 parent cff9970 commit 3044a62

File tree

7 files changed

+221
-82
lines changed

7 files changed

+221
-82
lines changed

paddle/fluid/API.spec

+1-1
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ paddle.fluid.layers.shuffle_channel (ArgSpec(args=['x', 'group', 'name'], vararg
290290
paddle.fluid.layers.temporal_shift (ArgSpec(args=['x', 'seg_num', 'shift_ratio', 'name'], varargs=None, keywords=None, defaults=(0.25, None)), ('document', 'd5945431cdcae3cda21914db5bbf383e'))
291291
paddle.fluid.layers.py_func (ArgSpec(args=['func', 'x', 'out', 'backward_func', 'skip_vars_in_backward_input'], varargs=None, keywords=None, defaults=(None, None)), ('document', '231f91231430f5dae2b757df22317c67'))
292292
paddle.fluid.layers.psroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '9bf0cc6b0717010b8ceec5dc2541d566'))
293-
paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'output_channels', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '454c7ea8c73313dd41513929d7526303'))
293+
paddle.fluid.layers.prroi_pool (ArgSpec(args=['input', 'rois', 'spatial_scale', 'pooled_height', 'pooled_width', 'name'], varargs=None, keywords=None, defaults=(1.0, 1, 1, None)), ('document', '466be691ac4c1cd7b88fccb40846afce'))
294294
paddle.fluid.layers.teacher_student_sigmoid_loss (ArgSpec(args=['input', 'label', 'soft_max_up_bound', 'soft_max_lower_bound'], varargs=None, keywords=None, defaults=(15.0, -15.0)), ('document', 'b0e07aa41caae04b07a8e8217cc96020'))
295295
paddle.fluid.layers.huber_loss (ArgSpec(args=['input', 'label', 'delta'], varargs=None, keywords=None, defaults=None), ('document', '9d93ee81f7a3e526d68bb280bc695d6c'))
296296
paddle.fluid.layers.kldiv_loss (ArgSpec(args=['x', 'target', 'reduction', 'name'], varargs=None, keywords=None, defaults=('mean', None)), ('document', '45f3ebbcb766fca84cb2fe6307086573'))

paddle/fluid/operators/prroi_pool_op.cc

+4-17
Original file line numberDiff line numberDiff line change
@@ -43,12 +43,6 @@ class PRROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
4343
"(Tensor), "
4444
"the output of PRROIPoolOp is a 4-D Tensor with shape "
4545
"(num_rois, output_channels, pooled_h, pooled_w).");
46-
AddAttr<int>(
47-
"output_channels",
48-
"(int), "
49-
"the number of channels of the output feature map. "
50-
"For a task of C classes of objects, output_channels should be "
51-
"(C + 1) for classification only.");
5246
AddAttr<float>("spatial_scale",
5347
"(float, default 1.0), "
5448
"Multiplicative spatial scale factor "
@@ -100,28 +94,18 @@ class PRROIPoolOp : public framework::OperatorWithKernel {
10094

10195
int pooled_height = ctx->Attrs().Get<int>("pooled_height");
10296
int pooled_width = ctx->Attrs().Get<int>("pooled_width");
103-
int output_channels = ctx->Attrs().Get<int>("output_channels");
10497
float spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
10598

106-
PADDLE_ENFORCE_EQ(
107-
input_dims[1], output_channels * pooled_height * pooled_width,
108-
"the channel of X(%d) should be equal to the product of "
109-
"output_channels(%d), pooled_height(%d) and pooled_width(%d)",
110-
input_dims[1], output_channels, pooled_height, pooled_width);
111-
11299
PADDLE_ENFORCE_GT(pooled_height, 0,
113100
"The pooled output height must be greater than 0");
114101
PADDLE_ENFORCE_GT(pooled_width, 0,
115102
"The pooled output width must be greater than 0");
116-
PADDLE_ENFORCE_GT(output_channels, 1,
117-
"The pooled output channels must greater than 1");
118103
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
119104
"The spatial scale must greater than 0.");
120105

121106
auto out_dims = input_dims;
122107
out_dims[0] = rois_dims[0];
123-
out_dims[1] =
124-
output_channels; // input_dims[1] / (pooled_height * pooled_width);
108+
out_dims[1] = input_dims[1];
125109
out_dims[2] = pooled_height;
126110
out_dims[3] = pooled_width;
127111
ctx->SetOutputDim("Out", out_dims);
@@ -145,6 +129,7 @@ class PRROIPoolGradOp : public framework::OperatorWithKernel {
145129
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
146130
"The gradient of X should not be null.");
147131
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
132+
ctx->SetOutputDim(framework::GradVarName("ROIs"), ctx->GetInputDim("ROIs"));
148133
}
149134

150135
protected:
@@ -164,9 +149,11 @@ class PRROIPoolGradDescMaker : public framework::SingleGradOpDescMaker {
164149
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
165150
op->SetType("prroi_pool_grad");
166151
op->SetInput("X", Input("X"));
152+
op->SetInput("Out", Output("Out"));
167153
op->SetInput("ROIs", Input("ROIs"));
168154
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
169155
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
156+
op->SetOutput(framework::GradVarName("ROIs"), InputGrad("ROIs"));
170157
op->SetAttrMap(Attrs());
171158
return op;
172159
}

paddle/fluid/operators/prroi_pool_op.cu

+37-17
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ DEVICE void PrRoIPoolingDistributeDiffCUDA(T* diff, const T top_diff,
4040
}
4141
}
4242

43+
template <typename T>
44+
DEVICE void GPUAccumulateRois(T* offset, T data) {
45+
paddle::platform::CudaAtomicAdd(offset, data);
46+
}
47+
4348
template <typename T>
4449
__global__ void GPUPRROIPoolForward(
4550
const int nthreads, const T* input_data, const T* input_rois,
@@ -78,7 +83,7 @@ __global__ void GPUPRROIPoolForward(
7883
T win_end_h = win_start_h + bin_size_h;
7984

8085
T win_size = max(static_cast<T>(0.0), bin_size_w * bin_size_h);
81-
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
86+
int input_channel = c;
8287
const T* offset_input_data =
8388
input_data +
8489
(roi_batch_id * input_channels + input_channel) * height * width;
@@ -110,10 +115,12 @@ __global__ void GPUPRROIPoolForward(
110115

111116
template <typename T>
112117
__global__ void GPUPRROIPoolBackward(
113-
const int nthreads, const T* input_rois, const T* output_grad_data,
114-
const float spatial_scale, const int input_channels, const int height,
115-
const int width, const int output_channels, const int pooled_height,
116-
const int pooled_width, const int* rois_batch_id_data, T* input_grad_data) {
118+
const int nthreads, const T* in_data, const T* input_rois,
119+
const T* output_grad_data, const float spatial_scale,
120+
const int input_channels, const int height, const int width,
121+
const int output_channels, const int pooled_height, const int pooled_width,
122+
const int* rois_batch_id_data, T* input_grad_data, const T* out_data,
123+
T* input_roi_grad_data) {
117124
int index = blockIdx.x * blockDim.x + threadIdx.x;
118125
int offset = blockDim.x * gridDim.x;
119126
for (int i = index; i < nthreads; i += offset) {
@@ -125,7 +132,7 @@ __global__ void GPUPRROIPoolBackward(
125132

126133
// set roi_batch_id
127134
int roi_batch_id = rois_batch_id_data[n];
128-
int input_channel = (c * pooled_height + ph) * pooled_width + pw;
135+
int input_channel = c;
129136
int input_offset =
130137
(roi_batch_id * input_channels + input_channel) * height * width;
131138
T* offset_input_grad_data = input_grad_data + input_offset;
@@ -137,6 +144,7 @@ __global__ void GPUPRROIPoolBackward(
137144
T roi_start_h = static_cast<T>(offset_input_rois[1]) * spatial_scale;
138145
T roi_end_w = static_cast<T>(offset_input_rois[2]) * spatial_scale;
139146
T roi_end_h = static_cast<T>(offset_input_rois[3]) * spatial_scale;
147+
T* offset_input_roi_grad_data = input_roi_grad_data + n * 4;
140148

141149
T roi_width = max(roi_end_w - roi_start_w, static_cast<T>(0.0));
142150
T roi_height = max(roi_end_h - roi_start_h, static_cast<T>(0.0));
@@ -171,6 +179,16 @@ __global__ void GPUPRROIPoolBackward(
171179
height, width, PrRoIPoolingDistributeDiffCUDA<T>);
172180
}
173181
}
182+
183+
const T* offset_out_data = out_data + i;
184+
const T* offset_in_data = in_data + input_offset;
185+
PrRoIPoolingCoorBackward(
186+
s_w, e_w, s_h, e_h, width, height, win_start_w, win_start_h, win_end_w,
187+
win_end_h, pw, ph, pooled_width, pooled_height, win_size, spatial_scale,
188+
offset_in_data, offset_out_data, offset_input_grad_data,
189+
offset_input_roi_grad_data, GPUAccumulateRois<T>,
190+
[](const T x, const T y) { return max(x, y); },
191+
[](const T x, const T y) { return min(x, y); });
174192
}
175193
}
176194

@@ -184,20 +202,15 @@ class GPUPRROIPoolOpKernel : public framework::OpKernel<T> {
184202

185203
auto pooled_height = ctx.Attr<int>("pooled_height");
186204
auto pooled_width = ctx.Attr<int>("pooled_width");
187-
auto output_channels = ctx.Attr<int>("output_channels");
188205
auto spatial_scale = ctx.Attr<float>("spatial_scale");
189206

190207
auto in_dims = in->dims();
191208
int batch_size = in_dims[0];
192209
int input_channels = in_dims[1];
210+
auto output_channels = input_channels;
193211
int height = in_dims[2];
194212
int width = in_dims[3];
195213

196-
PADDLE_ENFORCE_EQ(input_channels,
197-
output_channels * pooled_height * pooled_width,
198-
"the channels of input X should equal the product of "
199-
"output_channels x pooled_height x pooled_width");
200-
201214
int rois_num = rois->dims()[0];
202215
if (rois_num == 0) return;
203216

@@ -245,17 +258,20 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
245258
void Compute(const framework::ExecutionContext& ctx) const override {
246259
auto* in = ctx.Input<Tensor>("X");
247260
auto* rois = ctx.Input<LoDTensor>("ROIs");
261+
auto* out = ctx.Input<framework::Tensor>("Out");
248262

249263
auto* output_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
250264
auto* input_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
265+
auto* input_roi_grad =
266+
ctx.Output<LoDTensor>(framework::GradVarName("ROIs"));
251267

252268
auto pooled_height = ctx.Attr<int>("pooled_height");
253269
auto pooled_width = ctx.Attr<int>("pooled_width");
254-
auto output_channels = ctx.Attr<int>("output_channels");
255270
auto spatial_scale = ctx.Attr<float>("spatial_scale");
256271

257272
int rois_num = rois->dims()[0];
258273
int input_channels = in->dims()[1];
274+
auto output_channels = input_channels;
259275
int height = in->dims()[2];
260276
int width = in->dims()[3];
261277

@@ -280,6 +296,8 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
280296
input_grad->mutable_data<T>(ctx.GetPlace());
281297
math::SetConstant<DeviceContext, T> set_zero;
282298
set_zero(ctx.cuda_device_context(), input_grad, static_cast<T>(0));
299+
input_roi_grad->mutable_data<T>(ctx.GetPlace());
300+
set_zero(ctx.cuda_device_context(), input_roi_grad, static_cast<T>(0));
283301

284302
int output_grad_size = output_grad->numel();
285303
int blocks = NumBlocks(output_grad_size);
@@ -288,10 +306,12 @@ class GPUPRROIPoolGradOpKernel : public framework::OpKernel<T> {
288306
if (output_grad_size > 0) {
289307
GPUPRROIPoolBackward<
290308
T><<<blocks, threads, 0, ctx.cuda_device_context().stream()>>>(
291-
output_grad_size, rois->data<T>(), output_grad->data<T>(),
292-
spatial_scale, input_channels, height, width, output_channels,
293-
pooled_height, pooled_width, rois_batch_id_list_gpu.data<int>(),
294-
input_grad->mutable_data<T>(ctx.GetPlace()));
309+
output_grad_size, in->data<T>(), rois->data<T>(),
310+
output_grad->data<T>(), spatial_scale, input_channels, height,
311+
width, output_channels, pooled_height, pooled_width,
312+
rois_batch_id_list_gpu.data<int>(),
313+
input_grad->mutable_data<T>(ctx.GetPlace()), out->data<T>(),
314+
input_roi_grad->mutable_data<T>(ctx.GetPlace()));
295315
}
296316
}
297317
}

0 commit comments

Comments
 (0)