Skip to content

Commit

Permalink
fix bug for pool2d and pool2d_grad when kernel_size > in_h/in_w in xpu (
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyk0314 authored Apr 19, 2023
1 parent 28492bf commit b1d3ec1
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
8 changes: 4 additions & 4 deletions paddle/phi/kernels/xpu/pool_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ void Pool2dGradKernel(const Context& ctx,

PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad");
} else {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
if (kernel_size[0] > (in_h + paddings[0] + paddings[1])) {
kernel_size[0] = in_h + paddings[0] + paddings[1];
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
if (kernel_size[1] > (in_w + paddings[2] + paddings[3])) {
kernel_size[1] = in_w + paddings[2] + paddings[3];
}
if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
Expand Down
8 changes: 4 additions & 4 deletions paddle/phi/kernels/xpu/pool_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,11 @@ void Pool2dKernel(const Context& ctx,
int* index_data = nullptr;
int r = xpu::Error_t::SUCCESS;
if (!adaptive) {
if (kernel_size[0] > in_h) {
kernel_size[0] = in_h;
if (kernel_size[0] > (in_h + paddings[0] + paddings[1])) {
kernel_size[0] = in_h + paddings[0] + paddings[1];
}
if (kernel_size[1] > in_w) {
kernel_size[1] = in_w;
if (kernel_size[1] > (in_w + paddings[2] + paddings[3])) {
kernel_size[1] = in_w + paddings[2] + paddings[3];
}
if (pooling_type == "max") {
r = xpu::max_pool2d<XPUType>(
Expand Down

0 comments on commit b1d3ec1

Please sign in to comment.