Skip to content

Commit

Permalink
Improve im2col for certain cases (pytorch#715)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#715

Copy across pixels along width dim for im2col. This should help convolution with small number of input channels.

 Copy across pixels of input width if we can. We can only do this
 if the following conditions are met.
  1) If the number of groups is 1. For number of groups > 1, im2col
     doesn't copy data across groups.
  2) If dilation is 1. For dilation > 1, copying from input
     across channels is not sequential.
  3) For copy from the last channel (end of filter or
     end of image width) for the current filter,
     only copy if we have enough in the current channel.

Reviewed By: jspark1105

Differential Revision: D31227743

fbshipit-source-id: f331071e6939781e092ed912f6d69247f93aafc9
  • Loading branch information
dskhudia authored and facebook-github-bot committed Oct 14, 2021
1 parent 4e3241f commit 0375f13
Showing 1 changed file with 29 additions and 5 deletions.
34 changes: 29 additions & 5 deletions src/PackAWithIm2Col.cc
Original file line number Diff line number Diff line change
Expand Up @@ -573,14 +573,38 @@ void PackAWithIm2Col<T, accT, SPATIAL_DIM>::pack(const block_type_t& block) {
a_zero_pt_,
sizeof(T) * (j_blk_end - j_blk_start));
} else {
int chn_start_idx = j_blk_start % ic_per_group;
int src_offset =
((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] + w_in) *
conv_p_.IC + g * ic_per_group + chn_start_idx;
// fast path
// Copy across pixels of input width if we can. We can only do this
// if the following conditions are met. 1) If the number of groups
// is 1. For number of groups > 1, im2col
// doesn't copy data across groups.
// 2) If dilation is 1. For dilation > 1, copying from input
// across channels is not sequential.
// 3) For copy from the last channel (end of filter or
// end of image width) for the current filter,
// only copy if we have enough in the current channel.
//
if (conv_p_.G == 1 && conv_p_.dilation[1] == 1 &&
((s < (conv_p_.K[1] - 1) && w_in < (conv_p_.IN_DIM[1] - 1)) ||
((chn_start_idx + block.col_size) <= ic_per_group))) {
// left edge adjustment with s
j_blk_end = std::min(
(j_blk_id + conv_p_.K[1] - s) * ic_per_group,
block.col_start + block.col_size);
// right edge adjustment with w_in
j_blk_end = std::min(
(j_blk_id + conv_p_.IN_DIM[1] - w_in) * ic_per_group,
j_blk_end);
j += j_blk_end - j_blk_start - ic_per_group;
}
std::memcpy(
out + (i - block.row_start) * BaseType::blockColSize() +
j_blk_start - block.col_start,
sdata_ +
((n * conv_p_.IN_DIM[0] + h_in) * conv_p_.IN_DIM[1] +
w_in) *
conv_p_.IC +
g * ic_per_group + (j_blk_start % ic_per_group),
sdata_ + src_offset,
sizeof(T) * (j_blk_end - j_blk_start));
}
}
Expand Down

0 comments on commit 0375f13

Please sign in to comment.