From cf34b9a26b609109b18d6498f0608faddb7a911b Mon Sep 17 00:00:00 2001 From: Jianyu Huang Date: Tue, 6 Aug 2019 11:55:17 -0700 Subject: [PATCH] Back out "[fbgemm] Integrate VNNI into FBGEMM master branch" Summary: Original commit changeset: fcaa13cc3159 ASMJIT requires the CMake version to be 3.8 However, FBGEMM and PyTorch only need the CMake version to be 3.5+. This caused the build failure in FBGEMM: https://circleci.com/gh/pytorch/FBGEMM/122#build-timing/containers/0 Reviewed By: dskhudia Differential Revision: D16670547 fbshipit-source-id: 506714c3db1cb82cf98895f58f82f235128f5285 --- CMakeLists.txt | 2 - include/fbgemm/PackingTraits-inl.h | 50 --- include/fbgemm/Utils.h | 7 +- src/ExecuteKernelU8S8.cc | 47 +-- src/Fbgemm.cc | 18 +- src/GenerateKernel.h | 30 +- src/GenerateKernelU8S8S32ACC16.cc | 91 +++-- src/GenerateKernelU8S8S32ACC16Avx512.cc | 94 ++--- src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc | 102 ----- src/GenerateKernelU8S8S32ACC32.cc | 87 ++-- src/GenerateKernelU8S8S32ACC32Avx512.cc | 95 ++--- src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc | 431 -------------------- src/GroupwiseConv.h | 100 ++--- src/GroupwiseConvAcc32Avx2.cc | 177 ++++---- src/PackAMatrix.cc | 10 +- src/PackAWithIm2Col.cc | 14 +- src/PackAWithQuantRowOffset.cc | 14 +- src/PackAWithRowOffset.cc | 14 +- src/PackBMatrix.cc | 25 +- src/PackMatrix.cc | 9 +- src/PackWeightMatrixForGConv.cc | 8 +- src/Utils.cc | 3 - test/GConvTest.cc | 4 +- third_party/asmjit | 2 +- 24 files changed, 379 insertions(+), 1055 deletions(-) delete mode 100644 src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc delete mode 100644 src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 817f699ff3..b575e17677 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,10 +33,8 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc src/FbgemmI8Spmdm.cc src/GenerateKernelU8S8S32ACC16.cc src/GenerateKernelU8S8S32ACC16Avx512.cc - src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc src/GenerateKernelU8S8S32ACC32.cc src/GenerateKernelU8S8S32ACC32Avx512.cc - src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc src/GroupwiseConvAcc32Avx2.cc src/PackAMatrix.cc src/PackAWithIm2Col.cc diff --git a/include/fbgemm/PackingTraits-inl.h b/include/fbgemm/PackingTraits-inl.h index baccfadd69..76eb4252ed 100644 --- a/include/fbgemm/PackingTraits-inl.h +++ b/include/fbgemm/PackingTraits-inl.h @@ -222,53 +222,3 @@ struct PackingTraits< 128}; ///< Cache block for N dimension (multiple of NR). static constexpr int KCB{256}; ///< Cache block for K dimension. }; - -/** - * @brief Helper struct to type specialize for int16_t and int32_t together. - */ -template -struct is_16or32bit { - static constexpr bool value = - std::is_same::value || std::is_same::value; -}; - -/** - * @brief Packing parameter specialization for accumulation into 32-bit/16-bit - * integers. - * - * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t - * to int32_t accumulation and use the same blocking parameters as int32_t. - * - * This is picked when T is of int8 type (signed or unsigned) and instruction - * set is avx512_vnni. - */ -template -struct PackingTraits< - T, - accT, - inst_set_t::avx512_vnni, - typename std::enable_if< - is_8bit::value && is_16or32bit::value>::type> { - static constexpr int MR{8}; ///< Register block for M dimension. - static constexpr int NR_MIN{ - 16}; ///< Minimum register block for N dimension. - ///< 16 because 16*ROW_INTERLEAVE int8 elements - ///< completely fill a 512-bit wide vector. - static constexpr int NR{ - 32}; ///< Register block for N dimension. - ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements - ///< completely fill a 512-bit wide vector. Total registers used for - ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x - ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers - ///< for C accumulations. - - static constexpr int ROW_INTERLEAVE{ - 4}; ///< 4 rows are interleaved to use vpmaddubsw instruction for packing - ///< B matrix. - - static constexpr int MCB{ - 128}; ///< Cache block for M dimension (multiple of MR). - static constexpr int NCB{ - 32}; ///< Cache block for N dimension (multiple of NR). - static constexpr int KCB{256}; ///< Cache block for K dimension. -}; diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h index 3f8522bee8..107cf073f3 100644 --- a/include/fbgemm/Utils.h +++ b/include/fbgemm/Utils.h @@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose }; /** * @brief Typed enum for supported instruction sets. */ -enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni }; +enum class inst_set_t { anyarch, avx2, avx512 }; /** * @brief Typed enum for optimized paths for convolutions @@ -99,11 +99,6 @@ FBGEMM_API bool fbgemmHasAvx512Support(); */ FBGEMM_API bool fbgemmHasAvx2Support(); -/** - * @brief Are we running on a AVX512_VNNI supported cpu? - */ -FBGEMM_API bool fbgemmHasAvx512VnniSupport(); - /** * @brief Helper struct to enable autotuning of FBGEMM packing and kernels. * diff --git a/src/ExecuteKernelU8S8.cc b/src/ExecuteKernelU8S8.cc index 0a4ff550c5..f7292fd27e 100644 --- a/src/ExecuteKernelU8S8.cc +++ b/src/ExecuteKernelU8S8.cc @@ -49,8 +49,7 @@ ExecuteKernel< throw std::runtime_error("Failed to initialize cpuinfo!"); } if (params) { - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() || - fbgemmHasAvx2Support()) { + if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) { mbSize_ = params->MCB; nbSize_ = params->NCB; nrMinSize_ = params->NR_MIN; @@ -60,20 +59,7 @@ ExecuteKernel< assert(0 && "unsupported architecure"); } } else { - if (fbgemmHasAvx512VnniSupport()) { - mbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::MCB; - nbSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::NCB; - nrMinSize_ = PackingTraits< - int8_t, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::NR_MIN; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { mbSize_ = PackingTraits< int8_t, typename packingAMatrix::accType, @@ -132,25 +118,7 @@ void ExecuteKernel< typename BaseType::jit_micro_kernel_fp fn; - if (fbgemmHasAvx512VnniSupport()) { - if (std::is_same::value) { - // For AVX512VNNI, we redirect int16_t to int32_t accumulation. - CodeGenBase codeObj; - fn = codeObj.getOrCreate( - accum, - packed_rows_A, - packedB_.blockColSize(), - packedA_.numPackedCols(), - nbSize_); - } else { - fn = BaseType::template getOrCreate( - accum, - packed_rows_A, - packedB_.blockColSize(), - packedA_.numPackedCols(), - nbSize_); - } - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate( accum, packed_rows_A, @@ -180,10 +148,7 @@ void ExecuteKernel< if (jb == bColBlocks - 1) { int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_; if (nc != nbSize_) { - if (fbgemmHasAvx512VnniSupport()) { - fn = BaseType::template getOrCreate( - accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { fn = BaseType::template getOrCreate( accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_); } else if (fbgemmHasAvx2Support()) { @@ -248,7 +213,7 @@ void ExecuteKernel< int32_t nSize = C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols(); if (nSize) { - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f( @@ -273,7 +238,7 @@ void ExecuteKernel< if (C_buffer_start == C_tile_) { // When C_tile_ scratchpad was used to avoid accessing memory past // C_buffer_ . - if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { // TODO: avx512 path // Currently use avx2 code outputProcess_.template f( diff --git a/src/Fbgemm.cc b/src/Fbgemm.cc index 4f7026f7b0..0f2f6fbe45 100644 --- a/src/Fbgemm.cc +++ b/src/Fbgemm.cc @@ -48,8 +48,7 @@ void fbgemmPacked( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -63,20 +62,7 @@ void fbgemmPacked( MR = blocking_params->MR; } else { - if (fbgemmHasAvx512VnniSupport()) { - MCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::MCB; - KCB = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::KCB; - MR = PackingTraits< - typename packingAMatrix::inpType, - typename packingAMatrix::accType, - inst_set_t::avx512_vnni>::MR; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { MCB = PackingTraits< typename packingAMatrix::inpType, typename packingAMatrix::accType, diff --git a/src/GenerateKernel.h b/src/GenerateKernel.h index e52097e18e..dccdfc5513 100644 --- a/src/GenerateKernel.h +++ b/src/GenerateKernel.h @@ -18,7 +18,7 @@ namespace fbgemm { namespace x86 = asmjit::x86; /** - * @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator. + * @brief AVX2/AVX512 JIT assembly code generator. * @tparam TA Type of matrix A. * @tparam TB Type of matrix B. * @tparam TC Type of matrix C. @@ -104,7 +104,7 @@ class CodeGenBase { */ template void initCRegs( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCRegAssign = 4); @@ -114,10 +114,10 @@ class CodeGenBase { */ template void genComputeBlock( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp B_pf, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, int rowRegs, int colRegs, int lda, @@ -129,11 +129,11 @@ class CodeGenBase { */ template void storeCRegs( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCRegAssign = 4); @@ -168,9 +168,7 @@ class CodeGenBase { fileName += "_MR-" + std::to_string(MR); fileName += "_NR-" + std::to_string(NR); fileName += "_NR_MIN-" + std::to_string(NR_MIN); - if (instSet == inst_set_t::avx512_vnni) { - fileName += "_avx512vnni"; - } else if (instSet == inst_set_t::avx512) { + if (instSet == inst_set_t::avx512) { fileName += "_avx512"; } else if (instSet == inst_set_t::avx2) { fileName += "_avx2"; @@ -180,10 +178,12 @@ class CodeGenBase { } private: - x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. - x86::Zmm + asmjit::X86Ymm + CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel. + asmjit::X86Zmm CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel. - x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers. + asmjit::X86Zmm + AllRegs_avx512_[32]; ///< all AVX512 zmm registers. int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. diff --git a/src/GenerateKernelU8S8S32ACC16.cc b/src/GenerateKernelU8S8S32ACC16.cc index 1e7e08106c..718b8832d6 100644 --- a/src/GenerateKernelU8S8S32ACC16.cc +++ b/src/GenerateKernelU8S8S32ACC16.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase::initCRegs< inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,18 +53,18 @@ template <> template <> void CodeGenBase::genComputeBlock< inst_set_t::avx2>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp /* unused (reserved for prefetching)*/, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - x86::Ymm AReg = x86::ymm12; + asmjit::X86Ymm AReg = x86::ymm12; - x86::Ymm tmpReg = x86::ymm14; + asmjit::X86Ymm tmpReg = x86::ymm14; for (int i = 0; i < rowRegs; ++i) { // broadcast A @@ -95,15 +95,15 @@ template <> template <> void CodeGenBase::storeCRegs< inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCReg) { - x86::Xmm extractDest128 = x86::xmm15; - x86::Ymm extractDest256 = x86::ymm15; + asmjit::X86Xmm extractDest128 = x86::xmm15; + asmjit::X86Ymm extractDest256 = x86::ymm15; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast(i * sizeof(int32_t))); @@ -112,7 +112,7 @@ void CodeGenBase::storeCRegs< a->vextracti128( extractDest128, CRegs_avx2_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest256, extractDest128); - x86::Mem destAddr = x86::dword_ptr( + asmjit::X86Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest256, extractDest256, destAddr); @@ -176,9 +176,9 @@ CodeGenBase::getOrCreate( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -207,45 +207,46 @@ CodeGenBase::getOrCreate( //"nc must be equal to the number of register blocks"); // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT( + FuncSignature6( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp kIdx = a->gpz(14); + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; if (mRegBlocks > 0) { @@ -288,7 +289,8 @@ CodeGenBase::getOrCreate( a->jl(Loopk); // store C matrix - storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); // increment A for next block a->sub(buffer_A, kSize); @@ -338,10 +340,11 @@ CodeGenBase::getOrCreate( a->jl(LoopkRem); // store C matrix - storeCRegs(a, rowRegs, colRegs, C_Offset, ldcReg, accum); + storeCRegs( + a, rowRegs, colRegs, C_Offset, ldcReg, accum); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512.cc b/src/GenerateKernelU8S8S32ACC16Avx512.cc index a49e4406cf..c95757b5c3 100644 --- a/src/GenerateKernelU8S8S32ACC16Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC16Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase::initCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,18 +41,18 @@ template <> template <> void CodeGenBase::genComputeBlock< inst_set_t::avx512>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp /* unused (reserved for prefetching)*/, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp /* unused (reserved for prefetching)*/, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - x86::Zmm AReg = x86::zmm29; + asmjit::X86Zmm AReg = x86::zmm29; - x86::Zmm tmpReg = x86::zmm30; + asmjit::X86Zmm tmpReg = x86::zmm30; // We start allocating BRegs from zmm27 and then allocate zmm26 and so on. for (int j = 0; j < colRegs; ++j) { @@ -66,7 +66,8 @@ void CodeGenBase::genComputeBlock< a->vpbroadcastw( AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); for (int j = 0; j < colRegs; ++j) { - a->vpmaddubsw(tmpReg, AReg, AllRegs_avx512_[27 - j]); + a->vpmaddubsw( + tmpReg, AReg, AllRegs_avx512_[27-j]); a->vpaddsw( CRegs_avx512_[i * leadingDimCReg + j], tmpReg, @@ -89,16 +90,15 @@ template <> template <> void CodeGenBase::storeCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, - + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCReg) { - x86::Ymm extractDest256 = x86::ymm31; - x86::Zmm extractDest512 = x86::zmm31; + asmjit::X86Ymm extractDest256 = x86::ymm31; + asmjit::X86Zmm extractDest512 = x86::zmm31; for (int i = 0; i < rowRegs; ++i) { a->imul(C_Offset, ldcReg, static_cast(i * sizeof(int32_t))); @@ -107,7 +107,7 @@ void CodeGenBase::storeCRegs< a->vextracti32x8( extractDest256, CRegs_avx512_[i * leadingDimCReg + j], idx); a->vpmovsxwd(extractDest512, extractDest256); - x86::Mem destAddr = x86::dword_ptr( + asmjit::X86Mem destAddr = x86::dword_ptr( a->zcx(), C_Offset, 0, (j * 2 + idx) * 16 * sizeof(int32_t)); if (accum) { a->vpaddd(extractDest512, extractDest512, destAddr); @@ -172,9 +172,9 @@ CodeGenBase::getOrCreate( } code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -209,49 +209,49 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT( + FuncSignature6( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - // x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + // asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp jIdx = a->gpzRef(14); + asmjit::X86Gp kIdx = a->gpzRef(15); // save B_buffer address a->mov(buffer_B_saved, buffer_B); @@ -407,7 +407,7 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc deleted file mode 100644 index f559aba2a8..0000000000 --- a/src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc +++ /dev/null @@ -1,102 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include "GenerateKernel.h" - -namespace fbgemm { - -namespace x86 = asmjit::x86; - -/** - * Generate AVX512 instructions for initializing the C registers to 0 in 16-bit - * Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::initCRegs< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - int rowRegs, - int colRegs, - int leadingDimCReg) { - assert(0 && "Accumulation to int16_t is not available for VNNI!"); - - // For AVX512VNNI, redirect to int32_t accumulation. - CodeGenBase codeObj; - codeObj.initCRegs( - a, rowRegs, colRegs, leadingDimCReg); -} - -/** - * Generate AVX512 instructions for computing block in the rank-k update of - * 16-bit Accmulation kernel. - */ -template <> -template <> -void CodeGenBase::genComputeBlock< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp /* unused (reserved for prefetching)*/, - int rowRegs, - int colRegs, - int lda, - int leadingDimCReg) { - assert(0 && "Accumulation to int16_t is not available for VNNI!"); - - // For AVX512VNNI, redirect to int32_t accumulation. - CodeGenBase codeObj; - codeObj.genComputeBlock( - a, buffer_A, buffer_B, buffer_B, rowRegs, colRegs, lda, leadingDimCReg); -} - -/** - * Generate AVX512 instructions for storing the C registers back to the memory - * in 16-bit Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::storeCRegs< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - int rowRegs, - int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, - bool accum, - int leadingDimCReg) { - assert(0 && "Accumulation to int16_t is not available for VNNI!"); - - // For AVX512VNNI, redirect to int32_t accumulation. - CodeGenBase codeObj; - codeObj.storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, leadingDimCReg); -} - -/** - * Get or Create the AVX512 instructions for 16-bit Accumulation macro-kernel. - * - */ -template <> -template <> -CodeGenBase::jit_micro_kernel_fp -CodeGenBase::getOrCreate< - inst_set_t::avx512_vnni>( - bool accum, - int32_t mc, - int32_t nc, - int32_t kc, - int32_t /* unused */) { - assert(0 && "Accumulation to int16_t is not available for VNNI!"); - - // For AVX512VNNI, redirect to int32_t accumulation. - CodeGenBase codeObj; - return codeObj.getOrCreate(accum, mc, nc, kc, kc); -} - -} // namespace fbgemm diff --git a/src/GenerateKernelU8S8S32ACC32.cc b/src/GenerateKernelU8S8S32ACC32.cc index 6b547434a3..58643adc33 100644 --- a/src/GenerateKernelU8S8S32ACC32.cc +++ b/src/GenerateKernelU8S8S32ACC32.cc @@ -31,7 +31,7 @@ template <> template <> void CodeGenBase::initCRegs< inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -53,25 +53,25 @@ template <> template <> void CodeGenBase::genComputeBlock< inst_set_t::avx2>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp B_pf, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - x86::Ymm AReg = x86::ymm12; + asmjit::X86Ymm AReg = x86::ymm12; // used for matrix B - x86::Ymm BReg = x86::ymm13; + asmjit::X86Ymm BReg = x86::ymm13; // Contains 16-bit 1s - x86::Ymm oneReg = x86::ymm15; + asmjit::X86Ymm oneReg = x86::ymm15; // temporary register - x86::Ymm res1 = x86::ymm14; + asmjit::X86Ymm res1 = x86::ymm14; for (int j = 0; j < colRegs; ++j) { // load B @@ -99,11 +99,11 @@ template <> template <> void CodeGenBase::storeCRegs< inst_set_t::avx2>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { @@ -177,9 +177,9 @@ CodeGenBase::getOrCreate( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // generated code logging FILE* codeLogfile = fopen( @@ -205,48 +205,49 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT( + FuncSignature6( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14)); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14)); - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); + + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); asmjit::Label Loopk = a->newLabel(); asmjit::Label LoopMBlocks = a->newLabel(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp kIdx = a->gpz(14); - // x86::Gp B_pf = a->gpz(8); + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp kIdx = a->gpzRef(14); + // asmjit::X86Gp B_pf = a->gpzRef(8); - x86::Ymm oneReg = x86::ymm15; + asmjit::X86Ymm oneReg = x86::ymm15; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -357,7 +358,7 @@ CodeGenBase::getOrCreate( a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512.cc b/src/GenerateKernelU8S8S32ACC32Avx512.cc index fe356270bb..12243ee780 100644 --- a/src/GenerateKernelU8S8S32ACC32Avx512.cc +++ b/src/GenerateKernelU8S8S32ACC32Avx512.cc @@ -19,7 +19,7 @@ template <> template <> void CodeGenBase::initCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, int leadingDimCReg) { @@ -41,25 +41,25 @@ template <> template <> void CodeGenBase::genComputeBlock< inst_set_t::avx512>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp B_pf, + asmjit::X86Emitter* a, + asmjit::X86Gp buffer_A, + asmjit::X86Gp buffer_B, + asmjit::X86Gp B_pf, int rowRegs, int colRegs, int lda, int leadingDimCReg) { // used for matrix A - x86::Zmm AReg = x86::zmm31; + asmjit::X86Zmm AReg = x86::zmm31; // used for matrix B - x86::Zmm BReg = x86::zmm30; + asmjit::X86Zmm BReg = x86::zmm30; // Contains 16-bit 1s - x86::Zmm oneReg = x86::zmm29; + asmjit::X86Zmm oneReg = x86::zmm29; // temporary register - x86::Zmm res1 = x86::zmm28; + asmjit::X86Zmm res1 = x86::zmm28; for (int j = 0; j < colRegs; ++j) { // load B @@ -87,17 +87,18 @@ template <> template <> void CodeGenBase::storeCRegs< inst_set_t::avx512>( - x86::Emitter* a, + asmjit::X86Emitter* a, int rowRegs, int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, + asmjit::X86Gp C_Offset, + asmjit::X86Gp ldcReg, bool accum, int leadingDimCReg) { for (int i = 0; i < rowRegs; ++i) { if (i != 0) { a->add(C_Offset, ldcReg); - } else { + } + else { a->mov(C_Offset, static_cast(0)); } for (int j = 0; j < colRegs; ++j) { @@ -167,9 +168,9 @@ CodeGenBase::getOrCreate( return codeCache_[kernelSig]; } code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // generated code logging @@ -204,52 +205,52 @@ CodeGenBase::getOrCreate( int mRegBlocksRem = mc % mRegBlockSize; // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); + asmjit::X86Gp buffer_A = a->zdi(); + asmjit::X86Gp buffer_B = a->zsi(); + asmjit::X86Gp B_pf = a->zdx(); + asmjit::X86Gp CBase = a->zcx(); + asmjit::X86Gp kSize = a->gpzRef(8); + asmjit::X86Gp ldcReg = a->gpzRef(9); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT( + FuncSignature6( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); asmjit::Label LoopMBlocks = a->newLabel(); asmjit::Label LoopNBlocks = a->newLabel(); asmjit::Label Loopk = a->newLabel(); - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); + asmjit::X86Gp buffer_B_saved = a->gpzRef(10); + asmjit::X86Gp C_Offset = a->gpzRef(11); + asmjit::X86Gp B_pf_saved = a->gpzRef(12); + asmjit::X86Gp iIdx = a->gpzRef(13); + asmjit::X86Gp jIdx = a->gpzRef(14); + asmjit::X86Gp kIdx = a->gpzRef(15); + // asmjit::X86Gp B_pf = a->gpzRef(8); - x86::Zmm oneReg = x86::zmm29; + asmjit::X86Zmm oneReg = x86::zmm29; // create 16-bit 1s // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 // and so on @@ -419,7 +420,7 @@ CodeGenBase::getOrCreate( a->jl(LoopNRem); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_micro_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc b/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc deleted file mode 100644 index 8ae0745fb3..0000000000 --- a/src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc +++ /dev/null @@ -1,431 +0,0 @@ -/* - * Copyright (c) Facebook, Inc. and its affiliates. - * All rights reserved. - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ -#include -#include "GenerateKernel.h" - -namespace fbgemm { - -namespace x86 = asmjit::x86; - -/** - * Generate AVX512 instructions for initializing the C registers to 0 in 32-bit - * Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::initCRegs< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - int rowRegs, - int colRegs, - int leadingDimCReg) { - for (int i = 0; i < rowRegs; ++i) { - for (int j = 0; j < colRegs; ++j) { - a->vxorps( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j]); - } - } -} - -/** - * Generate AVX512 instructions for computing block in the rank-k update of - * 32-bit Accmulation kernel. - */ -template <> -template <> -void CodeGenBase::genComputeBlock< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - x86::Gp buffer_A, - x86::Gp buffer_B, - x86::Gp B_pf, - int rowRegs, - int colRegs, - int lda, - int leadingDimCReg) { - // used for matrix A - x86::Zmm AReg = x86::zmm31; - - // used for matrix B - x86::Zmm BReg = x86::zmm30; - - for (int j = 0; j < colRegs; ++j) { - // load B - a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t))); - // load A, broadcast and fmas - for (int i = 0; i < rowRegs; ++i) { - a->vpbroadcastd( - AReg, x86::dword_ptr(buffer_A, (i * lda) * sizeof(uint8_t))); - a->vpdpbusd(CRegs_avx512_[i * leadingDimCReg + j], AReg, BReg); - } - a->prefetcht0(x86::dword_ptr(B_pf, j * VLEN_ * sizeof(int8_t))); - } -} - -/** - * Generate AVX512 instructions for storing the C registers back to the memory - * in 32-bit Accumulation kernel. - */ -template <> -template <> -void CodeGenBase::storeCRegs< - inst_set_t::avx512_vnni>( - x86::Emitter* a, - int rowRegs, - int colRegs, - x86::Gp C_Offset, - x86::Gp ldcReg, - bool accum, - int leadingDimCReg) { - for (int i = 0; i < rowRegs; ++i) { - if (i != 0) { - a->add(C_Offset, ldcReg); - } else { - a->mov(C_Offset, static_cast(0)); - } - for (int j = 0; j < colRegs; ++j) { - if (accum) { - a->vpaddd( - CRegs_avx512_[i * leadingDimCReg + j], - CRegs_avx512_[i * leadingDimCReg + j], - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t))); - } - a->vmovups( - x86::dword_ptr(a->zcx(), C_Offset, 0, j * 16 * sizeof(int32_t)), - CRegs_avx512_[i * leadingDimCReg + j]); - } - } -} - -/** - * Get or Create the AVX512 instructions for 32-bit Accumulation macro-kernel. - * - */ -template <> -template <> -CodeGenBase::jit_micro_kernel_fp -CodeGenBase::getOrCreate< - inst_set_t::avx512_vnni>( - bool accum, - int32_t mc, - int32_t nc, - int32_t kc, - int32_t /* unused */) { - std::tuple kernelSig; - int kBlock; - int nBlock; - int mRegBlockSize; - int nRegBlockSize; - int nRegBlockSizeMin; - int row_interleave; - - if (blocking_params) { - kBlock = blocking_params->KCB; - nBlock = blocking_params->NCB; - mRegBlockSize = blocking_params->MR; - nRegBlockSize = blocking_params->NR; - nRegBlockSizeMin = blocking_params->NR_MIN; - row_interleave = blocking_params->ROW_INTERLEAVE; - } else { - kBlock = PackingTraits::KCB; - nBlock = PackingTraits::NCB; - mRegBlockSize = - PackingTraits::MR; - nRegBlockSize = - PackingTraits::NR; - nRegBlockSizeMin = - PackingTraits::NR_MIN; - row_interleave = PackingTraits:: - ROW_INTERLEAVE; - } - - kernelSig = std::make_tuple( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin); - - if (codeCache_.find(kernelSig) != codeCache_.end()) { - return codeCache_[kernelSig]; - } - code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); - -#if defined(FBGEMM_LOG_CODE) - // generated code logging - FILE* codeLogfile = fopen( - getCodeLoggingFile( - accum, - mc, - nc, - nBlock, - kBlock, - mRegBlockSize, - nRegBlockSize, - nRegBlockSizeMin) - .c_str(), - "w"); - asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile); - if (codeLogger) { - code_.setLogger(codeLogger); - } -#endif - - assert(kc % row_interleave == 0 && "kc must be a multiple of row_interleave"); - assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN"); - int maxMRegs = mRegBlockSize; - int maxNRegs = nRegBlockSize * row_interleave / VLEN_; - assert( - maxMRegs * maxNRegs <= 28 && - "MR*(NR*ROW_INTERLEAVE*8/512) \ - must be <= 28(available registers constraint)"); - - int mRegBlocks = mc / mRegBlockSize; - int mRegBlocksRem = mc % mRegBlockSize; - - // arguments to the function created - x86::Gp buffer_A = a->zdi(); - x86::Gp buffer_B = a->zsi(); - x86::Gp B_pf = a->zdx(); - x86::Gp CBase = a->zcx(); - x86::Gp kSize = a->gpz(8); - x86::Gp ldcReg = a->gpz(9); - - asmjit::FuncDetail func; - func.init( - asmjit:: - FuncSignatureT( - asmjit::CallConv::kIdHost)); - - asmjit::FuncFrame frame; - frame.init(func); - - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); - args.assignAll(buffer_A, buffer_B, B_pf, CBase, kSize, ldcReg); - - args.updateFuncFrame(frame); - frame.finalize(); - - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); - - asmjit::Label LoopMBlocks = a->newLabel(); - asmjit::Label LoopNBlocks = a->newLabel(); - asmjit::Label Loopk = a->newLabel(); - - x86::Gp buffer_B_saved = a->gpz(10); - x86::Gp C_Offset = a->gpz(11); - x86::Gp B_pf_saved = a->gpz(12); - x86::Gp iIdx = a->gpz(13); - x86::Gp jIdx = a->gpz(14); - x86::Gp kIdx = a->gpz(15); - // x86::Gp B_pf = a->gpz(8); - - x86::Zmm oneReg = x86::zmm29; - // create 16-bit 1s - // i.e., oneReg[0:15] contains 0x0001, oneReg[16:31] contains 0x0001 - // and so on - // a->vpcmpeqw(oneReg, oneReg, oneReg); - a->vpternlogd(oneReg, oneReg, oneReg, 0xff); - a->vpsrlw(oneReg, oneReg, 15); - a->imul(ldcReg, ldcReg, static_cast(sizeof(int32_t))); - - // save B_buffer address - a->mov(buffer_B_saved, buffer_B); - a->mov(B_pf_saved, B_pf); - - int currColRegs = nc * row_interleave * sizeof(int8_t) / VLEN_; - int colRegs = std::min(currColRegs, maxNRegs); - if (mRegBlocks > 0) { - // move 0 to iteration variables - a->mov(iIdx, 0); - - a->bind(LoopMBlocks); - a->inc(iIdx); - a->mov(jIdx, 0); - - a->bind(LoopNBlocks); - a->inc(jIdx); - - int rowRegs = mRegBlockSize; - - // init C registers - initCRegs(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(Loopk); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - - // a->add(B_pf, static_cast(32*sizeof(float))); - - a->cmp(kIdx, kSize); - a->jl(Loopk); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // reset A - a->sub(buffer_A, kSize); - - // B for next block - a->mov(buffer_B, buffer_B_saved); - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNBlocks); - - // increment A for next block - a->add( - buffer_A, static_cast((rowRegs)*kBlock * sizeof(uint8_t))); - - // increment C for next A block - a->sub( - CBase, - static_cast(jLoopTrips * nRegBlockSize * sizeof(int32_t))); - a->imul(C_Offset, ldcReg, static_cast(rowRegs)); - a->add(CBase, C_Offset); - - // reset B - a->mov(buffer_B, buffer_B_saved); - a->mov(B_pf, B_pf_saved); - a->cmp(iIdx, mRegBlocks); - a->jl(LoopMBlocks); - } - // generate code for remainder - if (mRegBlocksRem > 0) { - asmjit::Label LoopNRem = a->newLabel(); - asmjit::Label LoopkRem = a->newLabel(); - int rowRegs = mRegBlocksRem; - - a->mov(jIdx, 0); - a->bind(LoopNRem); - a->inc(jIdx); - - // init C registers - initCRegs(a, rowRegs, colRegs, colRegs); - - // init k loop index - a->mov(kIdx, 0); - a->bind(LoopkRem); - - // k is incremented by row_interleave - a->add(kIdx, static_cast(row_interleave)); - - genComputeBlock( - a, buffer_A, buffer_B, B_pf, rowRegs, colRegs, kBlock, colRegs); - - // update buffer_A address for next k iteration - a->add( - buffer_A, static_cast(row_interleave * sizeof(uint8_t))); - - // update buffer_B address for next k iteration - a->add( - buffer_B, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - a->add( - B_pf, - static_cast(nBlock * row_interleave * sizeof(int8_t))); - - a->cmp(kIdx, kSize); - a->jl(LoopkRem); - - // reset A - a->sub(buffer_A, kSize); - // B for next block - // using C_Offset as temp reg - a->imul( - C_Offset, - jIdx, - static_cast( - nRegBlockSize * row_interleave * sizeof(int8_t))); - a->mov(buffer_B, buffer_B_saved); - a->add(buffer_B, C_Offset); - a->mov(B_pf, B_pf_saved); - a->add(B_pf, C_Offset); - - // store C matrix - storeCRegs( - a, rowRegs, colRegs, C_Offset, ldcReg, accum, colRegs); - - // increment C for next B block - a->add(CBase, static_cast(nRegBlockSize * sizeof(int32_t))); - - int jLoopTrips = currColRegs / maxNRegs; - // jLoopTrips should be at least 1 - jLoopTrips = jLoopTrips ? jLoopTrips : 1; - a->cmp(jIdx, jLoopTrips); - a->jl(LoopNRem); - } - - a->emitEpilog(frame); - - jit_micro_kernel_fp fn; - asmjit::Error err = rt_.add(&fn, &code_); - if (err) { - std::cout << "Error: in fn add" << std::endl; - return nullptr; - } - codeCache_[kernelSig] = fn; - -#if defined(FBGEMM_LOG_CODE) - fclose(codeLogfile); - delete codeLogger; -#endif - - return fn; -} - -} // namespace fbgemm diff --git a/src/GroupwiseConv.h b/src/GroupwiseConv.h index 4c5eea57e7..1e6324e022 100644 --- a/src/GroupwiseConv.h +++ b/src/GroupwiseConv.h @@ -128,58 +128,60 @@ class GenConvKernel { const conv_param_t& conv_param); template - void createVector16BitOne(x86::Emitter* a); + void createVector16BitOne(asmjit::X86Emitter* a); template - void createVector8BitOne(x86::Emitter* a); + void createVector8BitOne(asmjit::X86Emitter* a); template - void setToZeroPt(x86::Emitter* a, x86::Ymm destReg); + void setToZeroPt(asmjit::X86Emitter* a, asmjit::X86Ymm destReg); template - void gen8bitFMA(x86::Emitter* a, x86::Ymm aReg, x86::Ymm wReg); + void + gen8bitFMA(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm wReg); template - void genForLoadingWeights(x86::Emitter* a, int c_offset); + void genForLoadingWeights(asmjit::X86Emitter* a, int c_offset); template - void genConstForPermutations(x86::Emitter* a); + void genConstForPermutations(asmjit::X86Emitter* a); template - void genForTopEdge(x86::Emitter* a, int c_offset); + void genForTopEdge(asmjit::X86Emitter* a, int c_offset); template - void genForLeftEdge(x86::Emitter* a, int c_offset); + void genForLeftEdge(asmjit::X86Emitter* a, int c_offset); template - void genForRightEdge(x86::Emitter* a, int c_offset); + void genForRightEdge(asmjit::X86Emitter* a, int c_offset); template - void genForBottomEdge(x86::Emitter* a, int c_offset); + void genForBottomEdge(asmjit::X86Emitter* a, int c_offset); template - void genCoreInsts(x86::Emitter* a, int c_offset); + void genCoreInsts(asmjit::X86Emitter* a, int c_offset); template - void storeResult(x86::Emitter* a); + void storeResult(asmjit::X86Emitter* a); // for Rowoffset kernel // Add 4 consecutive numbers of 32 uint8 and emit 8 32-bit template - void gen8BitSumX4(x86::Emitter* a, x86::Ymm aReg); + void gen8BitSumX4(asmjit::X86Emitter* a, asmjit::X86Ymm aReg); // Add 8 consecutive numbers of 64 uint8 and emit 8 32-bit template - void gen8BitSumX8(x86::Emitter* a, x86::Ymm aReg, x86::Ymm bReg); + void + gen8BitSumX8(asmjit::X86Emitter* a, asmjit::X86Ymm aReg, asmjit::X86Ymm bReg); // Add 16 consecutive numbers of 128 uint8 and emit 8 32-bit template void gen8BitSumX16( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg, - x86::Ymm cReg, - x86::Ymm dReg); + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg, + asmjit::X86Ymm cReg, + asmjit::X86Ymm dReg); // Generate instruction sequence that loads 8-bit values and sum them up. // Depending on C_per_G_, this function dispatches to gen8BitSumX4/8/16 @@ -189,33 +191,35 @@ class GenConvKernel { // Internally, actRegAvx2_, stPermRegAvx2_, WRegs_avx2_[0, 1], tmpReg1Avx2_, // and resultRegAvx2_ are used. template - void - gen8BitSum(x86::Emitter* a, int act_offset, bool use_scratch_reg1 = true); + void gen8BitSum( + asmjit::X86Emitter* a, + int act_offset, + bool use_scratch_reg1 = true); // Use scratchReg1_ and tmpReg1Avx2_ internally template - void genZeroPtSum(x86::Emitter* a, int multiplier); + void genZeroPtSum(asmjit::X86Emitter* a, int multiplier); template - void genForTopEdgeRowoffset(x86::Emitter* a); + void genForTopEdgeRowoffset(asmjit::X86Emitter* a); template - void genForLeftEdgeRowoffset(x86::Emitter* a); + void genForLeftEdgeRowoffset(asmjit::X86Emitter* a); template - void genForRightEdgeRowoffset(x86::Emitter* a); + void genForRightEdgeRowoffset(asmjit::X86Emitter* a); template - void genForBottomEdgeRowoffset(x86::Emitter* a); + void genForBottomEdgeRowoffset(asmjit::X86Emitter* a); template - void genRowoffsetCorners(x86::Emitter* a); + void genRowoffsetCorners(asmjit::X86Emitter* a); template - void genRowoffsetCore(x86::Emitter* a); + void genRowoffsetCore(asmjit::X86Emitter* a); template - void storeResultRowoffset(x86::Emitter* a, int offset = 0); + void storeResultRowoffset(asmjit::X86Emitter* a, int offset = 0); static thread_local asmjit::JitRuntime rt_; ///< JIT Runtime for asmjit. static thread_local asmjit::CodeHolder code_; ///< JIT Code Holder for asmjit. @@ -230,30 +234,30 @@ class GenConvKernel { int vectorWidth_; ///< Vector width in bits. int VLEN_; ///< Vector width in elements. // avx2 specific - x86::Ymm + asmjit::X86Ymm WRegs_avx2_[9]; ///< AVX2 ymm registers for weights in the micro-kernel. - x86::Ymm zeroPTRegAvx2_; - x86::Ymm tmpReg1Avx2_; - x86::Ymm stPermRegAvx2_; - x86::Ymm actRegAvx2_; - x86::Ymm resultRegAvx2_; - x86::Ymm oneReg8BitAvx2_; - x86::Ymm oneReg16BitAvx2_; + asmjit::X86Ymm zeroPTRegAvx2_; + asmjit::X86Ymm tmpReg1Avx2_; + asmjit::X86Ymm stPermRegAvx2_; + asmjit::X86Ymm actRegAvx2_; + asmjit::X86Ymm resultRegAvx2_; + asmjit::X86Ymm oneReg8BitAvx2_; + asmjit::X86Ymm oneReg16BitAvx2_; // arguments to the function created - x86::Gp in_acts_R_; - x86::Gp wghts_R_; - x86::Gp out_acts_R_; - x86::Gp a_zero_pt_R_; - x86::Gp H_R_; - x86::Gp W_R_; - x86::Gp row_offset_R_; + asmjit::X86Gp in_acts_R_; + asmjit::X86Gp wghts_R_; + asmjit::X86Gp out_acts_R_; + asmjit::X86Gp a_zero_pt_R_; + asmjit::X86Gp H_R_; + asmjit::X86Gp W_R_; + asmjit::X86Gp row_offset_R_; // Used registers - x86::Gp loopR1_; - x86::Gp loopR2_; - x86::Gp scratchReg1_; - x86::Gp scratchReg2_; + asmjit::X86Gp loopR1_; + asmjit::X86Gp loopR2_; + asmjit::X86Gp scratchReg1_; + asmjit::X86Gp scratchReg2_; // Other parameters bool isAZeroPointZero_; diff --git a/src/GroupwiseConvAcc32Avx2.cc b/src/GroupwiseConvAcc32Avx2.cc index b140c83ff5..e789695a39 100644 --- a/src/GroupwiseConvAcc32Avx2.cc +++ b/src/GroupwiseConvAcc32Avx2.cc @@ -104,7 +104,7 @@ jit_conv_kernel_fp getOrCreateConvKernel( template <> template <> void GenConvKernel<2, int32_t>::createVector8BitOne( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // create 8-bit 1s // i.e., oneReg16BitAvx2_[0:7] contains 0x01, oneReg8BitAvx2_[8:15] contains // 0x01 and so on @@ -115,7 +115,7 @@ void GenConvKernel<2, int32_t>::createVector8BitOne( template <> template <> void GenConvKernel<2, int32_t>::createVector16BitOne( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // create 16-bit 1s // i.e., oneReg16BitAvx2_[0:15] contains 0x0001, oneReg16BitAvx2_[16:31] // contains 0x0001 and so on @@ -125,11 +125,11 @@ void GenConvKernel<2, int32_t>::createVector16BitOne( template <> template <> void GenConvKernel<2, int32_t>::setToZeroPt( - x86::Emitter* a, - x86::Ymm destReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm destReg) { // make destReg all zeros a->vxorps(destReg, destReg, destReg); - x86::Xmm const_reg_xmm = x86::xmm10; + asmjit::X86Xmm const_reg_xmm = x86::xmm10; // move zero point to xmm10 a->movq(const_reg_xmm, a_zero_pt_R_); // make copies of zero point @@ -143,9 +143,9 @@ void GenConvKernel<2, int32_t>::setToZeroPt( template <> template <> void GenConvKernel<2, int32_t>::genConstForPermutations( - x86::Emitter* a) { - x86::Gp permute_const_reg = a->gpz(12); - x86::Xmm const_reg_xmm = x86::xmm10; + asmjit::X86Emitter* a) { + asmjit::X86Gp permute_const_reg = a->gpzRef(12); + asmjit::X86Xmm const_reg_xmm = x86::xmm10; // We have 1st group in even lanes and 2nd group in odd lanes. // Permute to put 1st group to lower 128-bit and 2nd group in upper // 128-bit. @@ -159,7 +159,8 @@ void GenConvKernel<2, int32_t>::genConstForPermutations( template <> template <> -void GenConvKernel<2, int32_t>::storeResult(x86::Emitter* a) { +void GenConvKernel<2, int32_t>::storeResult( + asmjit::X86Emitter* a) { if (C_per_G_ == 4) { // store with permutation a->vpermd(resultRegAvx2_, stPermRegAvx2_, resultRegAvx2_); @@ -170,7 +171,7 @@ void GenConvKernel<2, int32_t>::storeResult(x86::Emitter* a) { template <> template <> void GenConvKernel<2, int32_t>::storeResultRowoffset( - x86::Emitter* a, + asmjit::X86Emitter* a, int offset) { // store if (C_per_G_ == 4) { @@ -197,7 +198,7 @@ void GenConvKernel<2, int32_t>::storeResultRowoffset( template <> template <> void GenConvKernel<2, int32_t>::genForLoadingWeights( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // load weights for (int r = 0; r < R_; ++r) { @@ -224,9 +225,9 @@ void GenConvKernel<2, int32_t>::genForLoadingWeights( template <> template <> void GenConvKernel<2, int32_t>::gen8bitFMA( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm wReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm wReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, wReg); a->vpmaddwd(tmpReg1Avx2_, oneReg16BitAvx2_, tmpReg1Avx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -235,8 +236,8 @@ void GenConvKernel<2, int32_t>::gen8bitFMA( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX4( - x86::Emitter* a, - x86::Ymm aReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg) { a->vpmaddubsw(tmpReg1Avx2_, aReg, oneReg8BitAvx2_); a->vpmaddwd(tmpReg1Avx2_, tmpReg1Avx2_, oneReg16BitAvx2_); a->vpaddd(resultRegAvx2_, tmpReg1Avx2_, resultRegAvx2_); @@ -245,9 +246,9 @@ void GenConvKernel<2, int32_t>::gen8BitSumX4( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX8( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // Let a[0] denote 0th (LSB) 8-bit of aReg // After vpsadbw, a[0:2] = a[0] + ... + a[7] @@ -266,11 +267,11 @@ void GenConvKernel<2, int32_t>::gen8BitSumX8( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSumX16( - x86::Emitter* a, - x86::Ymm aReg, - x86::Ymm bReg, - x86::Ymm cReg, - x86::Ymm dReg) { + asmjit::X86Emitter* a, + asmjit::X86Ymm aReg, + asmjit::X86Ymm bReg, + asmjit::X86Ymm cReg, + asmjit::X86Ymm dReg) { a->vxorps(tmpReg1Avx2_, tmpReg1Avx2_, tmpReg1Avx2_); // After vpsadbw, a[0:2] = a[0] + ... + a[7] // a[8:10] = a[8] + ... + a[15] @@ -318,7 +319,7 @@ void GenConvKernel<2, int32_t>::gen8BitSumX16( template <> template <> void GenConvKernel<2, int32_t>::gen8BitSum( - x86::Emitter* a, + asmjit::X86Emitter* a, int act_offset, bool use_scratch_reg1 /*=true*/) { if (use_scratch_reg1) { @@ -384,11 +385,11 @@ void GenConvKernel<2, int32_t>::gen8BitSum( template <> template <> void GenConvKernel<2, int32_t>::genZeroPtSum( - x86::Emitter* a, + asmjit::X86Emitter* a, int multiplier) { a->mov(scratchReg1_, static_cast(multiplier)); // tmpReg1Avx2_ also uses xmm11 - x86::Xmm const_reg_xmm = x86::xmm11; + asmjit::X86Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, scratchReg1_); a->vpbroadcastd(tmpReg1Avx2_, const_reg_xmm); a->vpmulld(tmpReg1Avx2_, zeroPTRegAvx2_, tmpReg1Avx2_); @@ -398,7 +399,7 @@ void GenConvKernel<2, int32_t>::genZeroPtSum( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdge( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // top-left corner code if (c_offset == 0) { @@ -558,7 +559,7 @@ void GenConvKernel<2, int32_t>::genForTopEdge( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdge( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); @@ -625,7 +626,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdge( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdge( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -713,7 +714,7 @@ void GenConvKernel<2, int32_t>::genForRightEdge( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdge( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // bottom-left corner // we updating the last row @@ -905,7 +906,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdge( template <> template <> void GenConvKernel<2, int32_t>::genCoreInsts( - x86::Emitter* a, + asmjit::X86Emitter* a, int c_offset) { // main compute asmjit::Label LoopH = a->newLabel(); @@ -1010,9 +1011,9 @@ template <> jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1029,16 +1030,16 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate( wghts_R_ = a->zsi(); out_acts_R_ = a->zdx(); a_zero_pt_R_ = a->zcx(); - H_R_ = a->gpz(8); - W_R_ = a->gpz(9); - row_offset_R_ = a->gpz(10); + H_R_ = a->gpzRef(8); + W_R_ = a->gpzRef(9); + row_offset_R_ = a->gpzRef(10); // register for temporary use - scratchReg1_ = a->gpz(12); - scratchReg2_ = a->gpz(13); + scratchReg1_ = a->gpzRef(12); + scratchReg2_ = a->gpzRef(13); asmjit::FuncDetail func; - func.init(asmjit::FuncSignatureT< + func.init(asmjit::FuncSignature6< void, uint8_t*, int8_t*, @@ -1047,29 +1048,29 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate( int32_t, int32_t>(asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(in_acts_R_, wghts_R_, out_acts_R_, a_zero_pt_R_, H_R_, W_R_); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); createVector16BitOne(a); - loopR1_ = a->gpz(14); - loopR2_ = a->gpz(15); + loopR1_ = a->gpzRef(14); + loopR2_ = a->gpzRef(15); if (!isAZeroPointZero_) { setToZeroPt(a, zeroPTRegAvx2_); @@ -1094,7 +1095,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate( genCoreInsts(a, c); } - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_conv_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); @@ -1116,7 +1117,7 @@ jit_conv_kernel_fp GenConvKernel<2, int32_t>::getOrCreate( template <> template <> void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // top-left corner code // zero out the results register a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1212,7 +1213,7 @@ void GenConvKernel<2, int32_t>::genForTopEdgeRowoffset( template <> template <> void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // left edge excluding corners asmjit::Label LoopLeftEdge = a->newLabel(); a->mov(loopR1_, static_cast(H_PAD_)); @@ -1255,7 +1256,7 @@ void GenConvKernel<2, int32_t>::genForLeftEdgeRowoffset( template <> template <> void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // right edge excluding corners asmjit::Label LoopRightEdge = a->newLabel(); @@ -1325,7 +1326,7 @@ void GenConvKernel<2, int32_t>::genForRightEdgeRowoffset( template <> template <> void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // bottom-left corner // zero out a->vxorps(resultRegAvx2_, resultRegAvx2_, resultRegAvx2_); @@ -1428,7 +1429,7 @@ void GenConvKernel<2, int32_t>::genForBottomEdgeRowoffset( template <> template <> void GenConvKernel<2, int32_t>::genRowoffsetCore( - x86::Emitter* a) { + asmjit::X86Emitter* a) { // number of uint8 elements in input channels should be a multiple of 32 assert(C_ % 32 == 0); @@ -1490,9 +1491,9 @@ jit_rowoffset_kernel_fp GenConvKernel<2, int32_t>::getOrCreateRowOffset( const conv_param_t<2>& conv_param) { code_.reset(false); - code_.init(rt_.codeInfo()); - x86::Assembler assembler(&code_); - x86::Emitter* a = assembler.as(); + code_.init(rt_.getCodeInfo()); + asmjit::X86Assembler assembler(&code_); + asmjit::X86Emitter* a = assembler.asEmitter(); #if defined(FBGEMM_LOG_CODE) // log code to a file @@ -1509,45 +1510,45 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset( a_zero_pt_R_ = a->zsi(); H_R_ = a->zdx(); W_R_ = a->zcx(); - row_offset_R_ = a->gpz(8); + row_offset_R_ = a->gpzRef(8); // register for temporary use - scratchReg1_ = a->gpz(12); - scratchReg2_ = a->gpz(13); + scratchReg1_ = a->gpzRef(12); + scratchReg2_ = a->gpzRef(13); - loopR1_ = a->gpz(14); - loopR2_ = a->gpz(15); + loopR1_ = a->gpzRef(14); + loopR2_ = a->gpzRef(15); asmjit::FuncDetail func; func.init( asmjit:: - FuncSignatureT( + FuncSignature5( asmjit::CallConv::kIdHost)); - asmjit::FuncFrame frame; - frame.init(func); + asmjit::FuncFrameInfo ffi; + ffi.setDirtyRegs( + asmjit::X86Reg::kKindVec, + asmjit::Utils::mask(0, 1, 2, 3, 4, 5, 6, 7) | + asmjit::Utils::mask(8, 9, 10, 11, 12, 13, 14, 15)); + ffi.setDirtyRegs( + asmjit::X86Reg::kKindGp, asmjit::Utils::mask(10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupVec, - asmjit::Support::bitMask(0, 1, 2, 3, 4, 5, 6, 7) | - asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - frame.setDirtyRegs( - x86::Reg::kGroupGp, asmjit::Support::bitMask(8, 9, 10, 11, 12, 13, 14, 15)); - - asmjit::FuncArgsAssignment args(&func); + asmjit::FuncArgsMapper args(&func); args.assignAll(in_acts_R_, a_zero_pt_R_, H_R_, W_R_, row_offset_R_); - args.updateFuncFrame(frame); - frame.finalize(); + args.updateFrameInfo(ffi); + + asmjit::FuncFrameLayout layout; + layout.init(func, ffi); - a->emitProlog(frame); - a->emitArgsAssignment(frame, args); + asmjit::FuncUtils::emitProlog(a, layout); + asmjit::FuncUtils::allocArgs(a, layout, args); // This uses xmm10 register temporarily. Should come before // createVector8BitOne if (!isAZeroPointZero_) { // we can use xmm11 because ymm11 is used by tmpReg1Avx2_ - x86::Xmm const_reg_xmm = x86::xmm11; + asmjit::X86Xmm const_reg_xmm = x86::xmm11; a->movq(const_reg_xmm, a_zero_pt_R_); a->vpbroadcastd(zeroPTRegAvx2_, const_reg_xmm); @@ -1568,7 +1569,7 @@ GenConvKernel<2, int32_t>::getOrCreateRowOffset( genRowoffsetCore(a); - a->emitEpilog(frame); + asmjit::FuncUtils::emitEpilog(a, layout); jit_rowoffset_kernel_fp fn; asmjit::Error err = rt_.add(&fn, &code_); diff --git a/src/PackAMatrix.cc b/src/PackAMatrix.cc index 5fabf97d95..143e11d1ee 100644 --- a/src/PackAMatrix.cc +++ b/src/PackAMatrix.cc @@ -34,8 +34,7 @@ PackAMatrix::PackAMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -44,12 +43,7 @@ PackAMatrix::PackAMatrix( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512VnniSupport()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits::MCB; BaseType::bcol_ = PackingTraits::KCB; row_interleave_B_ = diff --git a/src/PackAWithIm2Col.cc b/src/PackAWithIm2Col.cc index 2aca27d15b..d7316546cd 100644 --- a/src/PackAWithIm2Col.cc +++ b/src/PackAWithIm2Col.cc @@ -49,8 +49,7 @@ PackAWithIm2Col::PackAWithIm2Col( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -59,12 +58,7 @@ PackAWithIm2Col::PackAWithIm2Col( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512VnniSupport()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits::MCB; BaseType::bcol_ = PackingTraits::KCB; row_interleave_B_ = @@ -484,9 +478,7 @@ int PackAWithIm2Col::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512VnniSupport()) { - return PackingTraits::MCB; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { return PackingTraits::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits::MCB; diff --git a/src/PackAWithQuantRowOffset.cc b/src/PackAWithQuantRowOffset.cc index 0af05e8488..0e5c5987ce 100644 --- a/src/PackAWithQuantRowOffset.cc +++ b/src/PackAWithQuantRowOffset.cc @@ -45,8 +45,7 @@ PackAWithQuantRowOffset::PackAWithQuantRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -55,12 +54,7 @@ PackAWithQuantRowOffset::PackAWithQuantRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512VnniSupport()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits::MCB; BaseType::bcol_ = PackingTraits::KCB; row_interleave_B_ = @@ -205,9 +199,7 @@ int PackAWithQuantRowOffset::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512VnniSupport()) { - return PackingTraits::MCB; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { return PackingTraits::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits::MCB; diff --git a/src/PackAWithRowOffset.cc b/src/PackAWithRowOffset.cc index e84c67b87a..733bf5cde4 100644 --- a/src/PackAWithRowOffset.cc +++ b/src/PackAWithRowOffset.cc @@ -39,8 +39,7 @@ PackAWithRowOffset::PackAWithRowOffset( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -49,12 +48,7 @@ PackAWithRowOffset::PackAWithRowOffset( BaseType::bcol_ = params->KCB; row_interleave_B_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512VnniSupport()) { - BaseType::brow_ = PackingTraits::MCB; - BaseType::bcol_ = PackingTraits::KCB; - row_interleave_B_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits::MCB; BaseType::bcol_ = PackingTraits::KCB; row_interleave_B_ = @@ -195,9 +189,7 @@ int PackAWithRowOffset::rowOffsetBufferSize( if (params) { return params->MCB; } else { - if (fbgemmHasAvx512VnniSupport()) { - return PackingTraits::MCB; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { return PackingTraits::MCB; } else if (fbgemmHasAvx2Support()) { return PackingTraits::MCB; diff --git a/src/PackBMatrix.cc b/src/PackBMatrix.cc index bf43fab7f7..0990edb1c3 100644 --- a/src/PackBMatrix.cc +++ b/src/PackBMatrix.cc @@ -188,8 +188,7 @@ PackBMatrix::PackBMatrix( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -198,12 +197,7 @@ PackBMatrix::PackBMatrix( BaseType::bcol_ = params->NCB; row_interleave_ = params->ROW_INTERLEAVE; } else { - if (fbgemmHasAvx512VnniSupport()) { - BaseType::brow_ = PackingTraits::KCB; - BaseType::bcol_ = PackingTraits::NCB; - row_interleave_ = - PackingTraits::ROW_INTERLEAVE; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { BaseType::brow_ = PackingTraits::KCB; BaseType::bcol_ = PackingTraits::NCB; row_interleave_ = @@ -323,16 +317,14 @@ void PackBMatrix::pack_unpack_( } template -void PackBMatrix::pack( - const block_type_t& block, - const BlockingFactors* params) { +void PackBMatrix::pack(const block_type_t& block, + const BlockingFactors* params) { pack_unpack_(block, const_cast(smat_), BaseType::getBuf(), true, params); } template -void PackBMatrix::unpack( - T* origin_buf, - const BlockingFactors* params) { +void PackBMatrix::unpack(T* origin_buf, + const BlockingFactors* params) { block_type_t blockB{BaseType::packedRowStart(), BaseType::numPackedRows(), BaseType::packedColStart(), @@ -360,9 +352,8 @@ int32_t PackBMatrix::addr(int32_t r, int32_t c) const { } template -void PackBMatrix::printPackedMatrix( - std::string name, - const BlockingFactors* params) { +void PackBMatrix::printPackedMatrix(std::string name, + const BlockingFactors* params) { std::cout << name << ":" << "[" << BaseType::numPackedRows() << ", " << BaseType::numPackedCols() << "]" << std::endl; diff --git a/src/PackMatrix.cc b/src/PackMatrix.cc index ff7b8424b7..c7503dd5f4 100644 --- a/src/PackMatrix.cc +++ b/src/PackMatrix.cc @@ -36,8 +36,7 @@ int PackMatrix::packedBufferSize( if (!cpuinfo_initialize()) { throw std::runtime_error("Failed to initialize cpuinfo!"); } - if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() && - !fbgemmHasAvx2Support())) { + if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) { assert(0 && "unknown architecure"); } @@ -47,11 +46,7 @@ int PackMatrix::packedBufferSize( NCB = params->NCB; KCB = params->KCB; } else { - if (fbgemmHasAvx512VnniSupport()) { - MCB = PackingTraits::MCB; - NCB = PackingTraits::NCB; - KCB = PackingTraits::KCB; - } else if (fbgemmHasAvx512Support()) { + if (fbgemmHasAvx512Support()) { MCB = PackingTraits::MCB; NCB = PackingTraits::NCB; KCB = PackingTraits::KCB; diff --git a/src/PackWeightMatrixForGConv.cc b/src/PackWeightMatrixForGConv.cc index f6ad59e558..ba6adf372e 100644 --- a/src/PackWeightMatrixForGConv.cc +++ b/src/PackWeightMatrixForGConv.cc @@ -106,7 +106,7 @@ inline int PackWeightMatrixForGConv::packed_index_( * on 2 groups at a time and full SIMD width can be efficiently utilized even * while working on 1 group at a time. * In this case, the layout is G (C/4) R S K 4 - */ +*/ template void PackWeightMatrixForGConv::pack_unpack_( @@ -148,9 +148,9 @@ void PackWeightMatrixForGConv::pack_unpack_( if (ispack) { transposeConvWeights(conv_param_, src, dst); } else { - // TODO: Wrap this as a inverseTransposeConvWeights()? - // For unpack & transposed, call transposeConvWeights() - // G (R S C/G) K/G => G K/G (R S C/G) + // TODO: Wrap this as a inverseTransposeConvWeights()? + // For unpack & transposed, call transposeConvWeights() + // G (R S C/G) K/G => G K/G (R S C/G) for (int r = 0; r < R; ++r) { for (int s = 0; s < S; ++s) { for (int k = 0; k < OC_per_G; ++k) { diff --git a/src/Utils.cc b/src/Utils.cc index 5214e41ba3..0fa620d1e3 100644 --- a/src/Utils.cc +++ b/src/Utils.cc @@ -202,7 +202,4 @@ bool fbgemmHasAvx2Support() { return (cpuinfo_has_x86_avx2()); } -bool fbgemmHasAvx512VnniSupport() { - return (cpuinfo_has_x86_avx512vnni()); -} } // namespace fbgemm diff --git a/test/GConvTest.cc b/test/GConvTest.cc index 8c1fb8253a..0074535404 100644 --- a/test/GConvTest.cc +++ b/test/GConvTest.cc @@ -465,8 +465,8 @@ TEST_P(fbgemmGConvPackTest, PackUnpackTest) { for (int i = 0; i < weight_len; ++i) { EXPECT_EQ(Bint8.data()[i], unpack_buf.data()[i]) << "Pack/Unpack results differ at index " << i - << ", Reference: " << static_cast(Bint8.data()[i]) - << ", Pack-Unpacked: " << static_cast(unpack_buf.data()[i]); + << ", Reference: " << static_cast (Bint8.data()[i]) + << ", Pack-Unpacked: " << static_cast (unpack_buf.data()[i]); } } // for each shape } diff --git a/third_party/asmjit b/third_party/asmjit index 5d40561d14..673dcefaa0 160000 --- a/third_party/asmjit +++ b/third_party/asmjit @@ -1 +1 @@ -Subproject commit 5d40561d14f93dc45613bfa03155d1dfb4f5825a +Subproject commit 673dcefaa048c5f5a2bf8b85daf8f7b9978d018a