diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h index e936409d67..c5f95634f2 100644 --- a/include/fbgemm/Fbgemm.h +++ b/include/fbgemm/Fbgemm.h @@ -20,7 +20,6 @@ #include "./FbgemmI8DepthwiseAvx2.h" #include "./FbgemmI8Spmdm.h" #include "./QuantUtilsAvx2.h" -#include "./QuantUtilsAvx512.h" #include "./Types.h" #include "./Utils.h" diff --git a/src/CodeGenHelpers.h b/src/CodeGenHelpers.h index f2c32fee06..8ba1d49c09 100644 --- a/src/CodeGenHelpers.h +++ b/src/CodeGenHelpers.h @@ -171,6 +171,19 @@ void genU8I8S32FMA( a->vpaddd(cReg, tmpReg, cReg); } +template < + inst_set_t INST_SET, + typename std::enable_if::type = 0> +void genU8I8S32FMA( + x86::Emitter* a, + typename simd_info::vec_reg_t aReg, + typename simd_info::vec_reg_t bReg, + typename simd_info::vec_reg_t cReg, + typename simd_info::vec_reg_t oneReg16Bit, + typename simd_info::vec_reg_t tmpReg) { + a->vpdpbusd(cReg, aReg, bReg); +} + /** * @brief Add 4 consecutive numbers of type uint8 * and emit their sum as 32-bit numbers. @@ -203,6 +216,19 @@ void genU8Sum4( /*a->vpaddd(dest, tmpReg, dest);*/ } +template < + inst_set_t INST_SET, + typename std::enable_if::type = 0> +void genU8Sum4( + x86::Emitter* a, + typename simd_info::vec_reg_t src, + typename simd_info::vec_reg_t dest, + typename simd_info::vec_reg_t oneReg16Bit, + typename simd_info::vec_reg_t tmpReg) { + gen8BitVectorOne(a, tmpReg); + a->vpdpbusd(dest, src, tmpReg); +} + /** * @brief Add 8 consecutive numbers of type uint8 * and emit their sum as 16-bit numbers. diff --git a/src/GroupwiseConv.cc b/src/GroupwiseConv.cc index 4a4783fa1f..f33a8e23bc 100644 --- a/src/GroupwiseConv.cc +++ b/src/GroupwiseConv.cc @@ -19,6 +19,7 @@ #include "./RefImplementations.h" #include "./TransposeUtils.h" #include "fbgemm/Fbgemm.h" +#include "fbgemm/QuantUtilsAvx512.h" namespace fbgemm { @@ -122,7 +123,20 @@ jit_conv_kernel_fp getOrCreateConvKernel( accum); if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512VnniSupport()) { + return GenConvKernel::codeCache_ + .getOrCreate(kernelSig, [&]() { + auto genObj = GenConvKernel( + conv_param, + a_zero_point, + needRowOffset, + isTopEdgeIncluded, + isBottomEdgeIncluded, + isTopBottomEdgeSame, + accum); + return genObj.getOrCreate(); + }); + } else if (fbgemmHasAvx512Support()) { return GenConvKernel::codeCache_ .getOrCreate(kernelSig, [&]() { auto genObj = GenConvKernel( @@ -891,7 +905,7 @@ void dispatchOutputProcessing( groups, outProcess.getActWScale()}; if (cpuinfo_initialize()) { - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) { if (C_per_G == 2) { if (a_zero_point == 0) { if (b_symmetric) { diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index a611e59799..96030052a2 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -21,6 +21,28 @@ #include "fbgemm/Utils.h" /*#define FBGEMM_LOG_CODE 1*/ +#define GCONV_INST_AVX2_HEADER \ + template \ + typename std::enable_if::type + +#define GCONV_INST_AVX512_AND_VNNI_HEADER \ + template \ + typename std::enable_if< \ + ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \ + void>::type + +#define GCONV_INST_DEF_AVX2_HEADER \ + template \ + template \ + typename std::enable_if::type + +#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER \ + template \ + template \ + typename std::enable_if< \ + ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \ + void>::type + namespace fbgemm { namespace x86 = asmjit::x86; @@ -218,22 +240,42 @@ class FBGEMM_API GenConvKernel jit_conv_kernel_fp getOrCreate(); - void genForLoadingWeights(x86::Emitter* a) {} + GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a); - void genConstForPermutations(x86::Emitter* a) {} + GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a); - void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom); + GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a); - void initResultRegs(x86::Emitter* a); + GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a); - void genCoreInsts(x86::Emitter* a); + GCONV_INST_AVX2_HEADER genForSingleFilterPoint( + x86::Emitter* a, + int r, + int s, + int act_s, + bool use_zero_reg); - void genForSingleFilterPoint( + GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint( x86::Emitter* a, int r, int s, int act_s, - bool use_zero_reg) {} + bool use_zero_reg); + + GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a); + + GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a); + + GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a); + + GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a); + + void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom); + + void initResultRegs(x86::Emitter* a); + + void genCoreInsts(x86::Emitter* a); + void genForSingleOutput( x86::Emitter* a, bool isLeft, @@ -241,9 +283,6 @@ class FBGEMM_API GenConvKernel bool isTop, bool isBottom); - void storeResult(x86::Emitter* a) {} - void storeOffset(x86::Emitter* a) {} - private: int GTogether_; // The number of iterations needed for K dim. @@ -287,39 +326,4 @@ template CodeCache GenConvKernelBase::codeCache_; -// forward declaration of specialized ISA specific functions -template <> -void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations( - x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint( - x86::Emitter* a, - int r, - int s, - int act_s, - bool use_zero_reg); - -template <> -void GenConvKernel<2, inst_set_t::avx512>::genConstForPermutations( - x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx512>::genForLoadingWeights(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx512>::storeResult(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx512>::storeOffset(x86::Emitter* a); -template <> -void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint( - x86::Emitter* a, - int r, - int s, - int act_s, - bool use_zero_reg); - } // namespace fbgemm diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index c5e01b4441..a5cb1faf87 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -16,8 +16,7 @@ using namespace std; namespace x86 = asmjit::x86; -template <> -void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations( +GCONV_INST_DEF_AVX2_HEADER GenConvKernel::genConstForPermutations( x86::Emitter* a) { if (this->C_per_G_ == 4) { x86::Gp permute_const_reg = a->gpz(12); @@ -45,8 +44,8 @@ void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations( } } -template <> -void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a) { +GCONV_INST_DEF_AVX2_HEADER GenConvKernel::genForLoadingWeights( + x86::Emitter* a) { using WRegs = x86::Ymm; int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; // load weights @@ -66,8 +65,7 @@ void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a) { } } -template <> -void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a) { +GCONV_INST_DEF_AVX2_HEADER GenConvKernel::storeResult(x86::Emitter* a) { if (GTogether_ > 1) { // store with permutation a->vpermd(x86::Ymm(9), stPermReg_V_, x86::Ymm(9)); @@ -113,8 +111,7 @@ void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a) { } } -template <> -void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a) { +GCONV_INST_DEF_AVX2_HEADER GenConvKernel::storeOffset(x86::Emitter* a) { switch (this->C_per_G_) { case 2: // store 128-bits containing rowoffset for four groups @@ -155,8 +152,7 @@ void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a) { } } -template <> -void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint( +GCONV_INST_DEF_AVX2_HEADER GenConvKernel::genForSingleFilterPoint( x86::Emitter* a, int r, int s, @@ -189,12 +185,12 @@ void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint( } // row offset if (this->needRowOffset_) { - genU8Sum4( + genU8Sum4( a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_); } // 32 * int8 weight product 32 * uint8 activation -> 8 // output(K_per_g * group_together) - genU8I8S32FMA( + genU8I8S32FMA( a, actReg_V_, WRegs(r * this->S_ + s), @@ -235,12 +231,26 @@ void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint( // FMA result is not final reduction on C_per_G, producing 8 output in // which consectutive 2 elements if summedforms one final output over // K_Per_G dimension - genU8I8S32FMA( + genU8I8S32FMA( a, actReg_V_, WRegs(0), x86::Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_); } } } +#define GENCONVKERNEL_FUNCS(S, IN) \ + template void GenConvKernel::genForLoadingWeights( \ + x86::Emitter * a); \ + template void GenConvKernel::genConstForPermutations( \ + x86::Emitter * a); \ + template void GenConvKernel::genForSingleFilterPoint( \ + x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg); \ + template void GenConvKernel::storeResult(x86::Emitter * a); \ + template void GenConvKernel::storeOffset(x86::Emitter * a); +GENCONVKERNEL_FUNCS(1, inst_set_t::avx2) +GENCONVKERNEL_FUNCS(2, inst_set_t::avx2) +GENCONVKERNEL_FUNCS(3, inst_set_t::avx2) +#undef GENCONVKERNEL_FUNCS + template class GenConvKernel<1, inst_set_t::avx2>; template class GenConvKernel<2, inst_set_t::avx2>; template class GenConvKernel<3, inst_set_t::avx2>; diff --git a/src/GroupwiseConvAcc32Avx512.cc b/src/GroupwiseConvAcc32Avx512.cc index 1cd01820a0..c310b2ba2b 100644 --- a/src/GroupwiseConvAcc32Avx512.cc +++ b/src/GroupwiseConvAcc32Avx512.cc @@ -15,9 +15,8 @@ using namespace std; namespace x86 = asmjit::x86; -template <> -void GenConvKernel<2, inst_set_t::avx512>::genConstForPermutations( - x86::Emitter* a) { +GCONV_INST_DEF_AVX512_AND_VNNI_HEADER +GenConvKernel::genConstForPermutations(x86::Emitter* a) { x86::Gp permute_const_reg_upper_half = a->gpz(12); x86::Gp permute_const_reg_lower_half = a->gpz(13); x86::Xmm const_reg_xmm = x86::xmm11; @@ -66,9 +65,8 @@ void GenConvKernel<2, inst_set_t::avx512>::genConstForPermutations( a->vpmovzxbd(stPermReg_V_, const_reg_xmm); } -template <> -void GenConvKernel<2, inst_set_t::avx512>::genForLoadingWeights( - x86::Emitter* a) { +GCONV_INST_DEF_AVX512_AND_VNNI_HEADER +GenConvKernel::genForLoadingWeights(x86::Emitter* a) { using WRegs = x86::Zmm; int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4; // load weights @@ -91,8 +89,8 @@ void GenConvKernel<2, inst_set_t::avx512>::genForLoadingWeights( } } -template <> -void GenConvKernel<2, inst_set_t::avx512>::storeResult(x86::Emitter* a) { +GCONV_INST_DEF_AVX512_AND_VNNI_HEADER +GenConvKernel::storeResult(x86::Emitter* a) { if (GTogether_ > 1) { // store with permutation a->vpermd(x86::Zmm(9), stPermReg_V_, x86::Zmm(9)); @@ -137,8 +135,8 @@ void GenConvKernel<2, inst_set_t::avx512>::storeResult(x86::Emitter* a) { } } -template <> -void GenConvKernel<2, inst_set_t::avx512>::storeOffset(x86::Emitter* a) { +GCONV_INST_DEF_AVX512_AND_VNNI_HEADER +GenConvKernel::storeOffset(x86::Emitter* a) { auto rowOffsetReg_V_Ymm = rowOffsetReg_V_.half(); auto rowOffsetReg_V_Xmm = rowOffsetReg_V_Ymm.half(); auto tmpReg1_V_Xmm = tmpReg1_V_.half().half(); @@ -186,8 +184,8 @@ void GenConvKernel<2, inst_set_t::avx512>::storeOffset(x86::Emitter* a) { } } -template <> -void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint( +GCONV_INST_DEF_AVX512_AND_VNNI_HEADER +GenConvKernel::genForSingleFilterPoint( x86::Emitter* a, int r, int s, @@ -223,7 +221,7 @@ void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint( // row offset if (this->needRowOffset_) { if (this->C_per_G_ == 2 || this->C_per_G_ == 4) { - genU8Sum4( + genU8Sum4( a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_); } else { // still use Ymm for Sum8 @@ -232,7 +230,7 @@ void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint( } // FMA if (this->C_per_G_ != 16) { - genU8I8S32FMA( + genU8I8S32FMA( a, actReg_V_, WRegs(r * this->S_ + s), @@ -253,12 +251,32 @@ void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint( // FMA result is not final reduction on C_per_G, producing 4 * 16 outputs // in which consectutive 4 elements if summed forms one final output over // K_Per_G dimension, we need 16 final 32bits outputs. - genU8I8S32FMA( + genU8I8S32FMA( a, actReg_V_, WRegs(0), WRegs(9 - k), oneReg16Bit_V_, tmpReg1_V_); } } } +#define GENCONVKERNEL_FUNCS(S, IN) \ + template void GenConvKernel::genForLoadingWeights(x86::Emitter* a); \ + template void GenConvKernel::genConstForPermutations( \ + x86::Emitter* a); \ + template void GenConvKernel::genForSingleFilterPoint( \ + x86::Emitter* a, int r, int s, int act_s, bool use_zero_reg); \ + template void GenConvKernel::storeResult(x86::Emitter* a); \ + template void GenConvKernel::storeOffset(x86::Emitter* a); +GENCONVKERNEL_FUNCS(1, inst_set_t::avx512) +GENCONVKERNEL_FUNCS(1, inst_set_t::avx512_vnni) +GENCONVKERNEL_FUNCS(2, inst_set_t::avx512) +GENCONVKERNEL_FUNCS(2, inst_set_t::avx512_vnni) +GENCONVKERNEL_FUNCS(3, inst_set_t::avx512) +GENCONVKERNEL_FUNCS(3, inst_set_t::avx512_vnni) +#undef GENCONVKERNEL_FUNCS +template class GenConvKernel<1, inst_set_t::avx512>; +template class GenConvKernel<1, inst_set_t::avx512_vnni>; template class GenConvKernel<2, inst_set_t::avx512>; +template class GenConvKernel<2, inst_set_t::avx512_vnni>; +template class GenConvKernel<3, inst_set_t::avx512>; +template class GenConvKernel<3, inst_set_t::avx512_vnni>; } // namespace fbgemm diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index 9a1e3042f0..86f85d4fc4 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -51,7 +51,7 @@ int PackWeightMatrixForGConv::numOfGroupsTogether( const conv_param_t& conv_param) { int OC_per_G = conv_param.OC / conv_param.G; int IC_per_G = conv_param.IC / conv_param.G; - if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) { // TODO: change to avx512 when avx512 support is available return std::max( simd_info::WIDTH_BYTES / OC_per_G /