Skip to content

Commit

Permalink
add direct conv to fbgemm (pytorch#901)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#901

Integrate direct convolution code path into Fbgemm.
The direct convolution entrance is integrated through FbgemmConv.cc.
ConvFastPath will automatically determine whether a given convolution can use the direct convolution branch:
 - if spatial_dim = 2, transposed conv, kh = 2, kw <= 6, stride = 1 or 2, padding = 0

Reviewed By: jspark1105

Differential Revision: D32273614

fbshipit-source-id: 16255395b4e14fae10c129ad98dd4445a5106989
  • Loading branch information
jiyuanzFB authored and facebook-github-bot committed Feb 8, 2022
1 parent 969941d commit f0f6ca7
Show file tree
Hide file tree
Showing 11 changed files with 1,639 additions and 35 deletions.
2 changes: 2 additions & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def get_fbgemm_generic_srcs(with_base = False):
"src/PackMatrix.cc",
"src/PackWeightMatrixForGConv.cc",
"src/PackWeightsForConv.cc",
"src/PackWeightsForDirectConv.cc",
"src/QuantUtils.cc",
"src/RowWiseSparseAdagradFused.cc",
"src/SparseAdagrad.cc",
Expand All @@ -61,6 +62,7 @@ def get_fbgemm_public_headers():
"include/fbgemm/FbgemmFPCommon.h",
"include/fbgemm/FbgemmI64.h",
"include/fbgemm/FbgemmI8DepthwiseAvx2.h",
"include/fbgemm/FbgemmI8DirectconvAvx2.h",
"include/fbgemm/FbgemmI8Spmdm.h",
"include/fbgemm/FbgemmPackMatrixB.h",
"include/fbgemm/FbgemmSparse.h",
Expand Down
23 changes: 23 additions & 0 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "./FbgemmBuild.h"
#include "./FbgemmEmbedding.h"
#include "./FbgemmI8DepthwiseAvx2.h"
#include "./FbgemmI8DirectconvAvx2.h"
#include "./FbgemmI8Spmdm.h"
#include "./QuantUtilsAvx2.h"
#include "./Types.h"
Expand Down Expand Up @@ -601,6 +602,10 @@ class FBGEMM_API PackWeightsForConv {
return W_dw_packed_;
}

/// Returns the packed weights when the direct convolution implementation
/// was selected at packing time; the returned shared_ptr is empty otherwise.
/// NOTE(review): null-ness contract inferred from the sibling W_*_packed_
/// getters — confirm in PackWeightsForConv's constructor.
std::shared_ptr<PackedDirectConvMatrix> getPackedWForDirectconv() {
  return W_dc_packed_;
}

std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
getPackedWForGroupwise() {
return W_gconv_packed_;
Expand Down Expand Up @@ -649,6 +654,8 @@ class FBGEMM_API PackWeightsForConv {
std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
// Packed weights if we use depthwise convolution implementation
std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
// Packed weights if we use direct convolution implementation
std::shared_ptr<PackedDirectConvMatrix> W_dc_packed_;
// Packed weights if we use groupwise (small channels per group) convolution
// implementation
std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
Expand Down Expand Up @@ -1384,6 +1391,22 @@ FBGEMM_API void fbgemmGroupwiseConv(
int thread_id,
int num_threads);

/**
 * @brief Specialized direct-convolution path (2D transposed conv with
 *        small filters; see takeDirectConvPath in src/FbgemmConv.cc for
 *        the exact eligibility conditions).
 *
 * @tparam SPATIAL_DIM number of spatial dimensions
 * @tparam Q_GRAN quantization granularity of the requantization step
 * @tparam FUSE_RELU whether ReLU is fused into the requantization
 * @tparam BIAS_TYPE data type of the bias term
 *
 * @param conv_p convolution parameters
 * @param Aint8 uint8 input activations
 * @param Bint8_tr weights packed for direct convolution
 * @param C uint8 requantized output
 * @param C_buffer int32 accumulation buffer
 * @param outProcess requantization output pipeline
 * @param bias bias values; the call site in FbgemmConv.cc passes
 *        outProcess.getBias() here — presumably kept separate for the
 *        kernel's convenience (TODO confirm)
 * @param thread_id id of this thread in [0, num_threads)
 * @param num_threads total number of threads
 */
template <
    int SPATIAL_DIM,
    QuantizationGranularity Q_GRAN,
    bool FUSE_RELU,
    typename BIAS_TYPE = std::int32_t>
FBGEMM_API void fbgemmDirectConv(
    const conv_param_t<SPATIAL_DIM>& conv_p,
    const uint8_t* Aint8,
    PackedDirectConvMatrix& Bint8_tr,
    uint8_t* C,
    int32_t* C_buffer,
    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
    const BIAS_TYPE* bias,
    int thread_id,
    int num_threads);

/**
* @return Size of row offset buffer in number of elements needed for
* fbgemmGroupwiseConv
Expand Down
60 changes: 60 additions & 0 deletions include/fbgemm/FbgemmI8DirectconvAvx2.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#pragma once

#include <array>
#include <cstdint>
#include <vector>
#include "fbgemm/ConvUtils.h"
#include "fbgemm/FbgemmBuild.h"
#include "fbgemm/UtilsAvx2.h"

namespace fbgemm {

class FBGEMM_API PackedDirectConvMatrix {
public:
/**
* @param IC the number of input channels
* @param OC the number of output channels
* @param kernel_prod the product of all kernels. For example, kernel_prod =
* 9 for 3x3 conv, and 27 for 3x3x3 conv.
* @param smat the source unpacked weight in GRS layout
*/
PackedDirectConvMatrix(
int IC_per_G,
int OC_per_G,
int filter_prod,
const std::int8_t* smat);
virtual ~PackedDirectConvMatrix();

const std::int8_t* PackedMat() const {
return pmat_;
}

const bool& is_first_call() const {
return first_call;
}

/**
compute the column offsets of the weight matrix.
output of this function is the col_offsets vector
col_offses dimension is the same as conv_p.OUT_DIM
*/
template <int kSpatialDim>
FBGEMM_API void col_offsets_with_zero_pt_s8acc32_DirectConvT(
const fbgemm::conv_param_t<kSpatialDim>& conv_p,
std::int32_t* B_zero_point,
std::vector<int32_t>& col_offsets,
int ncols_per_quant_group);

private:
std::int8_t* pmat_; /** packed weight */
bool first_call{true};
};

} // namespace fbgemm
3 changes: 2 additions & 1 deletion include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,8 @@ enum class optimized_conv_t {
groupwise,
pointwise,
fastpath1d,
im2col
im2col,
directconv
};

/**
Expand Down
22 changes: 11 additions & 11 deletions src/DirectConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,6 @@ namespace x86 = asmjit::x86;
*/
void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);

static asmjit::JitRuntime& runtime() {
  static asmjit::JitRuntime rt; //< JIT Runtime for asmjit;
                                // depends on other static
                                // variables. Required to prevent
                                // initialization order fiasco
  return rt;
}

template <typename TA, typename TB, typename TC, typename accT>
class DirectConvCodeGenBase {
public:
Expand Down Expand Up @@ -164,9 +156,8 @@ class DirectConvCodeGenBase {
* store that into the code cache.
*/
template <inst_set_t instSet>
jit_micro_kernel_fp_convT getOrCreateDirectConvTrans(
bool accum,
int32_t stride);
jit_micro_kernel_fp_convT
getOrCreateDirectConvTrans(bool accum, int32_t stride, int32_t numColRegs);

/**
* @brief Generate instructions for computing block in the rank-k update.
Expand Down Expand Up @@ -205,6 +196,15 @@ class DirectConvCodeGenBase {
x86::Gp C_offset,
int rowRegs,
int colRegs);

private:
  static asmjit::JitRuntime& runtime() {
    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit;
                                  // depends on other static
                                  // variables. Required to prevent
                                  // initialization order fiasco
    return rt;
  }
};

template <typename TA, typename TB, typename TC, typename accT>
Expand Down
60 changes: 60 additions & 0 deletions src/FbgemmConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,35 @@ bool take1DFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
return false && !conv_p.transposed;
}

template <int SPATIAL_DIM, typename ACC_T>
bool takeDirectConvPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
  // Note: Direct convolutions (2D) are optimized for
  //   filter size: 2 x 1 to 2 x 6, transposed conv,
  //   in_channel % 8 == 0, out_channel % 8 == 0,
  //   stride = 1 or 2, all dilations = 1,
  //   padding = 0 (non-zero padding will be supported soon)
  bool ret = std::is_same<ACC_T, std::int32_t>::value && conv_p.transposed &&
      conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 &&
      std::all_of(
          conv_p.stride.begin(),
          conv_p.stride.end(),
          [](int i) { return i == 1 || i == 2; }) &&
      SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 &&
      conv_p.K[SPATIAL_DIM - 1] <= 6 &&
      std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) {
        return i == 1;
      });

  // Require zero padding on every side: conv_p.pad holds 2 * SPATIAL_DIM
  // entries ({top, left, bottom, right} for 2D), so scan the whole array
  // rather than only the first SPATIAL_DIM entries — otherwise a nonzero
  // bottom/right pad would incorrectly select this zero-padding-only path.
  ret = ret &&
      std::all_of(conv_p.pad.begin(), conv_p.pad.end(), [](int p) {
        return p == 0;
      });

  return ret;
}

template <int SPATIAL_DIM, typename ACC_T>
optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
Expand All @@ -75,6 +104,8 @@ optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
return optimized_conv_t::groupwise;
} else if (takePointWiseFastPath<SPATIAL_DIM>(conv_p)) {
return optimized_conv_t::pointwise;
} else if (takeDirectConvPath<SPATIAL_DIM, ACC_T>(conv_p)) {
return optimized_conv_t::directconv;
} else if (take1DFastPath<SPATIAL_DIM>(conv_p)) {
return optimized_conv_t::fastpath1d;
} else {
Expand Down Expand Up @@ -309,6 +340,26 @@ int fbgemmConv(
blocking_params);
break;
}
case optimized_conv_t::directconv: {
// specialized direct convolution path
// std::cout << "Directconv fast path" << std::endl;
static_assert(
std::is_same<typename processOutputType::outType, std::uint8_t>::
value,
"For directconv 2d, only requantized output is supported");
fbgemmDirectConv<SPATIAL_DIM, processOutputType::QGRANType>(
conv_p,
// Aint8,
activations,
*(packed_weights.getPackedWForDirectconv()),
out,
outBuffer,
outProcess,
outProcess.getBias(),
thread_id,
num_threads);
break;
}
case optimized_conv_t::fastpath1d: {
break;
}
Expand Down Expand Up @@ -416,6 +467,15 @@ template bool takeDepthWiseFastPath<2, std::int16_t>(
template bool takeDepthWiseFastPath<3, std::int16_t>(
const conv_param_t<3>& conv_p);

template bool takeDirectConvPath<2, std::int32_t>(
const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int32_t>(
const conv_param_t<3>& conv_p);
template bool takeDirectConvPath<2, std::int16_t>(
const conv_param_t<2>& conv_p);
template bool takeDirectConvPath<3, std::int16_t>(
const conv_param_t<3>& conv_p);

template FBGEMM_API optimized_conv_t
ConvFastPath<1, std::int32_t>(const conv_param_t<1>& conv_p);
template FBGEMM_API optimized_conv_t
Expand Down
51 changes: 28 additions & 23 deletions src/GenerateKernelDirectConvU8S8S32ACC32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace fbgemm {

namespace x86 = asmjit::x86;
/**
* Generate AVX512 instructions for computing block in the rank-k update of
* Generate AVX256 instructions for computing block in the rank-k update of
* 32-bit Accumulation kernel.
*
* this compute block implements the following register blocking
Expand All @@ -33,7 +33,7 @@ namespace x86 = asmjit::x86;
*/

/**
* Generate AVX512 instructions for storing the C registers back to the memory
* Generate AVX256 instructions for storing the C registers back to the memory
* in 32-bit Accumulation kernel.
*/
template <>
Expand Down Expand Up @@ -421,7 +421,7 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
}

/**
* Generate AVX512 instructions for storing the C registers back to the memory
* Generate AVX256 instructions for storing the C registers back to the memory
* in 32-bit Accumulation kernel.
*/
template <>
Expand Down Expand Up @@ -454,7 +454,7 @@ void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegsTrans(
}

/**
* Generate AVX512 instructions for computing block in the rank-k update of
* Generate AVX256 instructions for computing block in the rank-k update of
* 32-bit Accumulation kernel.
The function generates the register blocking code for transposed
Expand Down Expand Up @@ -588,22 +588,32 @@ void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
*
*/

/**
* Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel.
*
*/
template <>
template <inst_set_t instSet>
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
jit_micro_kernel_fp_convT
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
getOrCreateDirectConvTrans(bool accum, int32_t stride) {
getOrCreateDirectConvTrans(
bool accum,
int32_t stride,
int32_t numColRegs) {
using VecRegT = typename simd_info<instSet>::vec_reg_t;
constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

std::tuple<bool, int, int, int> kernelSig;
constexpr int mRowRegBlockSize = 2;
constexpr int mColRegBlockSize = 6;
constexpr int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
constexpr int nRegBlockSize = 8;
constexpr int row_interleave = 4;
// int ichSize = 32;
int mRowRegBlockSize = 2;
int mColRegBlockSize = numColRegs;
int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
int nRegBlockSize = 8;
// int nRegBlockSizeMin;
int row_interleave = 4;

kernelSig = std::make_tuple(accum, stride, mRegBlockSize, nRegBlockSize);

Expand All @@ -615,8 +625,7 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
#if defined(FBGEMM_LOG_CODE)
// generated code logging
FILE* codeLogfile = fopen(
getCodeLoggingFile<instSet>(
accum, stride, 0, 0, 0, mRegBlockSize, nRegBlockSize)
getCodeLoggingFile<instSet>(accum, stride, mRegBlockSize, nRegBlockSize)
.c_str(),
"w");
asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
Expand Down Expand Up @@ -689,9 +698,6 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::

gen16BitVectorOne<instSet, VecRegT>(a, oneReg);
a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
// a->xor_(C_Offset.r32(), C_Offset.r32());

// a->mov(B_pf_saved, B_pf);

int colRegs = maxNRegs;

Expand All @@ -702,7 +708,6 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
initCRegs(a, rowRegs, colRegs);

// Loops over K: input channel
// corresponds to the "icb" loop in the pseudo code
a->xor_(kIdx.r32(), kIdx.r32());
a->bind(LoopKLabel);

Expand Down Expand Up @@ -753,13 +758,12 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
// B for next block
a->mov(buffer_B, buffer_B_saved);
// increment C for next B block
// ldcReg already multiplied by 4 (sizeof(int32_t))
a->imul(C_offset, ldcReg, static_cast<asmjit::Imm>(stride));
a->imul(
C_offset,
ldcReg,
static_cast<asmjit::Imm>(stride)); // ldcReg already multiplied by 4
a->add(CBase, C_offset);

// a->add(CBase, static_cast<asmjit::Imm>(12*16*4));
// storeCRegs<instSet>(a, 12, 1, C_Offset, ldcReg, accum);

a->cmp(iIdx, i1);
a->jl(LoopMBlocks);
}
Expand Down Expand Up @@ -889,6 +893,7 @@ template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
getOrCreateDirectConvTrans<inst_set_t::avx2>(
bool accum,
int32_t stride);
int32_t stride,
int32_t numColRegs);

} // namespace fbgemm
Loading

0 comments on commit f0f6ca7

Please sign in to comment.