Skip to content

Commit

Permalink
Bug fix in depthwise conv (pytorch#395)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#395

Depthwise conv was crashing when output image height or width was less than the pad size.

Fixes: pytorch/pytorch#41406

Reviewed By: jianyuh

Differential Revision: D22645420

fbshipit-source-id: 9bb4822410b674d548a893e817fe4f007beaa758
  • Loading branch information
dskhudia authored and facebook-github-bot committed Jul 21, 2020
1 parent 503e033 commit 139c6f2
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/FbgemmI8Depthwise2DAvx2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,8 @@ static ALWAYS_INLINE void depthwise_2d_(
int h = 0;
int w = 0;

for (h = h_begin; h < PAD_T; ++h) {
for (w = w_begin; w < PAD_L; ++w) {
for (h = h_begin; h < std::min(PAD_T, h_end); ++h) {
for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
Expand Down Expand Up @@ -285,7 +285,7 @@ static ALWAYS_INLINE void depthwise_2d_(
// h_in + S - H <= PAD_B * (1 - stride_h) + 1 + (1 - stride_h) * stride_h
// <= -PAD_B + 1 - stride_h <= 0
for (; h < std::min(H_OUT - PAD_B - stride_h + 1, h_end); ++h) {
for (w = w_begin; w < PAD_L; ++w) {
for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
Expand Down Expand Up @@ -399,7 +399,7 @@ static ALWAYS_INLINE void depthwise_2d_(
}

for (; h < h_end; ++h) {
for (w = w_begin; w < PAD_L; ++w) {
for (w = w_begin; w < std::min(PAD_L, w_end); ++w) {
depthwise_2d_kernel_<
S,
FUSE_RELU,
Expand Down
4 changes: 4 additions & 0 deletions test/I8DepthwiseTest.cc
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ static vector<vector<int>> shapes = {
// { 100, 544, 14, 14, 2, 3 },

{ 1, 8, 4, 4, 1, 3 },
// Tests for the shapes when OH/OW is less than padding
{ 1, 72, 1, 1, 2, 5 },
{ 1, 72, 7, 1, 2, 5 },
{ 1, 72, 1, 7, 2, 5 },
};

static vector<vector<int>> shapes_3d = {
Expand Down

0 comments on commit 139c6f2

Please sign in to comment.