Skip to content

Commit

Permalink
unified conv to call dw conv with 2 oc per g (pytorch#360)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#360

The unified conv path now calls the optimized depth-wise conv when there are 2 output channels per group. This diff also enables per-group quantization.

Reviewed By: dskhudia

Differential Revision: D20984689

fbshipit-source-id: 6b60afb6a6ae819e104ba59fa5e5fa1ea1429e65
  • Loading branch information
jspark1105 authored and facebook-github-bot committed May 6, 2020
1 parent dfbb13f commit 40f530d
Show file tree
Hide file tree
Showing 7 changed files with 298 additions and 46 deletions.
3 changes: 2 additions & 1 deletion bench/ConvUnifiedBenchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,11 @@ vector<conv_param_t<2>> shapes_2d = {
// DW
conv_param_t<>(1, 272, 272, {47, 125}, 272, {3, 3},
{1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 256, {32, 100}, 128, {3, 3},
{1, 1}, {1, 1, 1, 1}),
// Pointwise
conv_param_t<>(1, 128, 128, {56, 56}, 1, {1, 1},
{1, 1}, {0, 0, 0, 0})

};

// 3D conv shapes
Expand Down
87 changes: 69 additions & 18 deletions src/FbgemmConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ bool takeDepthWiseFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
// Note: Depthwise convolutions (both 2D and 3D) are optimized for the most
// common case.
return std::is_same<ACC_T, std::int32_t>::value && conv_p.G == conv_p.IC &&
conv_p.G == conv_p.OC && conv_p.G % 8 == 0 &&
(conv_p.G == conv_p.OC || conv_p.G * 2 == conv_p.OC) &&
conv_p.G % 8 == 0 &&
std::all_of(
conv_p.stride.begin(),
conv_p.stride.end(),
Expand Down Expand Up @@ -101,37 +102,63 @@ int fbgemmConv(
"For depthwise, only requantized output is supported");

if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
depthwise_3x3x3_pad_1(
depthwise_3x3x3_pad_1<QuantizationGranularity::TENSOR>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // T
conv_p.IN_DIM[1], // H
conv_p.IN_DIM[2], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_t
conv_p.stride[1], // stride_h
conv_p.stride[2], // stride_w
outProcess.getAZeroPoint(),
activations,
B_zero_point[0],
B_zero_point,
*(packed_weights.getPackedWForDepthwise()),
C_multiplier[0],
C_multiplier,
outProcess.getCZeroPoint(),
out,
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
act_times_w_scale ? act_times_w_scale[0] : 1.0f,
act_times_w_scale,
thread_id,
num_threads);
} else if (
processOutputType::QGRANType ==
QuantizationGranularity::OUT_CHANNEL ||
processOutputType::QGRANType == QuantizationGranularity::GROUP) {
depthwise_3x3x3_per_channel_quantization_pad_1(
depthwise_3x3x3_pad_1<QuantizationGranularity::GROUP>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // T
conv_p.IN_DIM[1], // H
conv_p.IN_DIM[2], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_t
conv_p.stride[1], // stride_h
conv_p.stride[2], // stride_w
outProcess.getAZeroPoint(),
activations,
B_zero_point,
*(packed_weights.getPackedWForDepthwise()),
C_multiplier,
outProcess.getCZeroPoint(),
out,
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
act_times_w_scale, // act_scale * weight_scale
thread_id,
num_threads);
} else if (
processOutputType::QGRANType ==
QuantizationGranularity::OUT_CHANNEL) {
depthwise_3x3x3_pad_1<QuantizationGranularity::OUT_CHANNEL>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // T
conv_p.IN_DIM[1], // H
conv_p.IN_DIM[2], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_t
conv_p.stride[1], // stride_h
Expand All @@ -146,7 +173,7 @@ int fbgemmConv(
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
outProcess.getActWScale(), // act_scale * weight_scale
act_times_w_scale, // act_scale * weight_scale
thread_id,
num_threads);
} else {
Expand All @@ -157,35 +184,59 @@ int fbgemmConv(
}
} else if (SPATIAL_DIM == 2) {
if (processOutputType::QGRANType == QuantizationGranularity::TENSOR) {
depthwise_2d_same_pad(
depthwise_2d_same_pad<QuantizationGranularity::TENSOR>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // H
conv_p.IN_DIM[1], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_h
conv_p.stride[1], // stride_w
outProcess.getAZeroPoint(),
activations,
B_zero_point[0],
B_zero_point,
*(packed_weights.getPackedWForDepthwise()),
C_multiplier[0],
C_multiplier,
outProcess.getCZeroPoint(),
out,
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
act_times_w_scale ? act_times_w_scale[0] : 1.0f,
act_times_w_scale,
thread_id,
num_threads);
} else if (
processOutputType::QGRANType ==
QuantizationGranularity::OUT_CHANNEL ||
processOutputType::QGRANType == QuantizationGranularity::GROUP) {
// The number of input channels == groups for depthwise convolutions
depthwise_2d_per_channel_quantization_same_pad(
depthwise_2d_same_pad<QuantizationGranularity::GROUP>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // H
conv_p.IN_DIM[1], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_h
conv_p.stride[1], // stride_w
outProcess.getAZeroPoint(),
activations,
B_zero_point,
*(packed_weights.getPackedWForDepthwise()),
C_multiplier,
outProcess.getCZeroPoint(),
out,
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
act_times_w_scale, // act_scale * weight_scale
thread_id,
num_threads);
} else if (
processOutputType::QGRANType ==
QuantizationGranularity::OUT_CHANNEL) {
// The number of input channels == groups for depthwise convolutions
depthwise_2d_same_pad<QuantizationGranularity::OUT_CHANNEL>(
conv_p.MB, // mini batch
conv_p.IN_DIM[0], // H
conv_p.IN_DIM[1], // W
conv_p.IC, // input channels
conv_p.OC, // output channels
conv_p.stride[0], // stride_h
conv_p.stride[1], // stride_w
Expand All @@ -199,7 +250,7 @@ int fbgemmConv(
outProcess.getColOffsets(),
outProcess.getBias(),
outProcess.RELU_FUSED, // fuse_relu
outProcess.getActWScale(), // act_scale * weight_scale
act_times_w_scale, // act_scale * weight_scale
thread_id,
num_threads);
} else {
Expand Down
48 changes: 48 additions & 0 deletions src/FbgemmI8Depthwise3DAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1127,6 +1127,54 @@ template FBGEMM_API void depthwise_3x3x3_pad_1<QuantizationGranularity::TENSOR>(
int thread_id,
int num_threads);

// Explicit instantiation of the 3x3x3 depthwise kernel for per-GROUP
// quantization granularity with int32 bias (B_zero_point, C_multiplier and
// act_times_w_scale are per-group arrays). Added so the unified conv path
// can dispatch to depthwise_3x3x3_pad_1<GROUP> (see FbgemmConv.cc).
template FBGEMM_API void depthwise_3x3x3_pad_1<QuantizationGranularity::GROUP>(
int N,
int T,
int H,
int W,
int IC,
int OC,
int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const int32_t* bias,
bool fuse_relu,
const float* act_times_w_scale,
int thread_id,
int num_threads);

// Same per-GROUP explicit instantiation as above, but with float bias
// (BIAS_TYPE = float); all other parameters are identical.
template FBGEMM_API void depthwise_3x3x3_pad_1<QuantizationGranularity::GROUP>(
int N,
int T,
int H,
int W,
int IC,
int OC,
int stride_t,
int stride_h,
int stride_w,
int32_t A_zero_point,
const uint8_t* A,
const int32_t* B_zero_point,
const PackedDepthWiseConvMatrix& B,
const float* C_multiplier,
int32_t C_zero_point,
uint8_t* C,
const int32_t* col_offsets,
const float* bias,
bool fuse_relu,
const float* act_times_w_scale,
int thread_id,
int num_threads);

template FBGEMM_API void
depthwise_3x3x3_pad_1<QuantizationGranularity::OUT_CHANNEL>(
int N,
Expand Down
Loading

0 comments on commit 40f530d

Please sign in to comment.