Skip to content

Commit

Permalink
Support pad for GeometryPoolGrad
Browse files Browse the repository at this point in the history
  • Loading branch information
xiaying committed Jun 22, 2021
1 parent f6422c3 commit 2733909
Showing 1 changed file with 18 additions and 25 deletions.
43 changes: 18 additions & 25 deletions source/geometry/GeometryPoolGrad.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,19 +47,16 @@ class GeometryPoolGrad : public GeometryComputer {
stride_h = ih;
pad_w = 0;
pad_h = 0;
}

if (parameter->padType() == PoolPadType_SAME) {
int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
pad_w = pad_w_total > 0 ? pad_w_total / 2 : 0;
pad_h = pad_h_total > 0 ? pad_h_total / 2 : 0;
} else if (parameter->padType() == PoolPadType_VALID) {
pad_w = 0;
pad_h = 0;
} else {
MNN_PRINT("Pool padtype not supported!\n");
return false;
if (parameter->padType() == PoolPadType_SAME) {
int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
pad_w = pad_w_total > 0 ? pad_w_total / 2 : 0;
pad_h = pad_h_total > 0 ? pad_h_total / 2 : 0;
} else if (parameter->padType() == PoolPadType_VALID) {
pad_w = 0;
pad_h = 0;
}
}

std::vector<std::shared_ptr<Tensor>> originSplit;
Expand Down Expand Up @@ -304,21 +301,17 @@ class GeometryPoolGrad : public GeometryComputer {
stride_h = ih;
pad_w = 0;
pad_h = 0;
}

if (parameter->padType() == PoolPadType_SAME) {
int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
pad_w = pad_w_total > 0 ? pad_w_total / 2 : 0;
pad_h = pad_h_total > 0 ? pad_h_total / 2 : 0;
} else if (parameter->padType() == PoolPadType_VALID) {
pad_w = 0;
pad_h = 0;
} else {
MNN_PRINT("Pool padtype not supported!\n");
return false;
if (parameter->padType() == PoolPadType_SAME) {
int pad_w_total = (ow - 1) * stride_w + kernel_w - iw;
int pad_h_total = (oh - 1) * stride_h + kernel_h - ih;
pad_w = pad_w_total > 0 ? pad_w_total / 2 : 0;
pad_h = pad_h_total > 0 ? pad_h_total / 2 : 0;
} else if (parameter->padType() == PoolPadType_VALID) {
pad_w = 0;
pad_h = 0;
}
}

std::shared_ptr<Tensor> inpDifTrans;

inpDifTrans.reset(new Tensor);
Expand Down

0 comments on commit 2733909

Please sign in to comment.