Skip to content

Commit

Permalink
Integrate VNNI into FBGEMM master branch (pytorch#114)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#114

Add VNNI support to FBGEMM.

Previously, we had an issue with the CMake version: the PyTorch and FBGEMM OSS tests run with CMake 3.5, while ASMJIT required CMake 3.8+. This caused build failures on some platforms. The CMake version issue is now resolved by a PR to ASMJIT that lowers its CMake requirement: asmjit/asmjit#252.

Reviewed By: dskhudia

Differential Revision: D16720839

fbshipit-source-id: e5e5f2d26f924df8d9fb955f4a3758561fa73288
  • Loading branch information
jianyuh authored and facebook-github-bot committed Aug 9, 2019
1 parent 122135c commit 7b15607
Show file tree
Hide file tree
Showing 24 changed files with 1,055 additions and 379 deletions.
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ set(FBGEMM_GENERIC_SRCS src/ExecuteKernel.cc
src/FbgemmI8Spmdm.cc
src/GenerateKernelU8S8S32ACC16.cc
src/GenerateKernelU8S8S32ACC16Avx512.cc
src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc
src/GenerateKernelU8S8S32ACC32.cc
src/GenerateKernelU8S8S32ACC32Avx512.cc
src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
src/GroupwiseConvAcc32Avx2.cc
src/PackAMatrix.cc
src/PackAWithIm2Col.cc
Expand Down
50 changes: 50 additions & 0 deletions include/fbgemm/PackingTraits-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -222,3 +222,53 @@ struct PackingTraits<
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.
};

/**
 * @brief Type trait whose `value` is true exactly when T is int16_t or
 * int32_t, so that both accumulation types can share one specialization.
 */
template <typename T>
struct is_16or32bit {
  static constexpr bool value = std::is_same<T, int32_t>::value
      ? true
      : std::is_same<T, int16_t>::value;
};

/**
 * @brief Packing parameter specialization for accumulation into 32-bit/16-bit
 * integers.
 *
 * Since there is no int16_t accumulation for AVX512 VNNI, we redirect int16_t
 * to int32_t accumulation and use the same blocking parameters as int32_t.
 *
 * This is picked when T is of int8 type (signed or unsigned) and instruction
 * set is avx512_vnni.
 *
 * The divisibility relations documented on the members below are enforced by
 * the static_asserts at the end of the struct, so retuning a constant cannot
 * silently violate them.
 */
template <typename T, typename accT>
struct PackingTraits<
    T,
    accT,
    inst_set_t::avx512_vnni,
    typename std::enable_if<
        is_8bit<T>::value && is_16or32bit<accT>::value>::type> {
  static constexpr int MR{8}; ///< Register block for M dimension.
  static constexpr int NR_MIN{
      16}; ///< Minimum register block for N dimension.
           ///< 16 because 16*ROW_INTERLEAVE int8 elements
           ///< completely fill a 512-bit wide vector.
  static constexpr int NR{
      32}; ///< Register block for N dimension.
           ///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
           ///< completely fill a 512-bit wide vector. Total registers used for
           ///< N dimension: NR*ROW_INTERLEAVE*8/VLEN. We use MR x
           ///< NR*ROW_INTERLEAVE*8/VLEN zmm registers
           ///< for C accumulations.

  static constexpr int ROW_INTERLEAVE{
      4}; ///< 4 rows are interleaved when packing the B matrix.
          ///< NOTE(review): this comment previously named vpmaddubsw, which
          ///< looks copied from the non-VNNI traits; for AVX512 VNNI the
          ///< consuming instruction is presumably vpdpbusd -- confirm against
          ///< the VNNI kernel generator.

  static constexpr int MCB{
      128}; ///< Cache block for M dimension (multiple of MR).
  static constexpr int NCB{
      32}; ///< Cache block for N dimension (multiple of NR).
  static constexpr int KCB{256}; ///< Cache block for K dimension.

  // Compile-time checks of the invariants stated in the comments above.
  static_assert(
      NR_MIN * ROW_INTERLEAVE * 8 == 512,
      "NR_MIN*ROW_INTERLEAVE int8 elements must fill a 512-bit vector");
  static_assert(NR % NR_MIN == 0, "NR must be a multiple of NR_MIN");
  static_assert(MCB % MR == 0, "MCB must be a multiple of MR");
  static_assert(NCB % NR == 0, "NCB must be a multiple of NR");
};
7 changes: 6 additions & 1 deletion include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
enum class inst_set_t { anyarch, avx2, avx512 };
// avx512_vnni is the value used to select the AVX512 VNNI packing traits and
// JIT kernels.
enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };

/**
* @brief Typed enum for optimized paths for convolutions
Expand Down Expand Up @@ -99,6 +99,11 @@ FBGEMM_API bool fbgemmHasAvx512Support();
*/
FBGEMM_API bool fbgemmHasAvx2Support();

/**
* @brief Are we running on an AVX512_VNNI supported CPU?
*
* @return presumably true when the host CPU supports AVX512 VNNI; the
* implementation is not visible here -- TODO confirm.
*/
FBGEMM_API bool fbgemmHasAvx512VnniSupport();

/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
Expand Down
47 changes: 41 additions & 6 deletions src/ExecuteKernelU8S8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ ExecuteKernel<
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if (params) {
if (fbgemmHasAvx512Support() || fbgemmHasAvx2Support()) {
if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support() ||
fbgemmHasAvx2Support()) {
mbSize_ = params->MCB;
nbSize_ = params->NCB;
nrMinSize_ = params->NR_MIN;
Expand All @@ -59,7 +60,20 @@ ExecuteKernel<
assert(0 && "unsupported architecure");
}
} else {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::MCB;
nbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::NCB;
nrMinSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::NR_MIN;
} else if (fbgemmHasAvx512Support()) {
mbSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
Expand Down Expand Up @@ -118,7 +132,25 @@ void ExecuteKernel<

typename BaseType::jit_micro_kernel_fp fn;

if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport()) {
if (std::is_same<typename packingAMatrix::accType, std::int16_t>::value) {
// For AVX512VNNI, we redirect int16_t to int32_t accumulation.
CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
} else {
fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
}
} else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
Expand Down Expand Up @@ -148,7 +180,10 @@ void ExecuteKernel<
if (jb == bColBlocks - 1) {
int nc = ((packedB_.lastBcol() - 1) / nrMinSize_ + 1) * nrMinSize_;
if (nc != nbSize_) {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
} else if (fbgemmHasAvx2Support()) {
Expand Down Expand Up @@ -213,7 +248,7 @@ void ExecuteKernel<
int32_t nSize =
C_buffer_start == C_tile_ ? jb * nbSize_ : packedB_.numCols();
if (nSize) {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
Expand All @@ -238,7 +273,7 @@ void ExecuteKernel<
if (C_buffer_start == C_tile_) {
// When C_tile_ scratchpad was used to avoid accessing memory past
// C_buffer_ .
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport() || fbgemmHasAvx512Support()) {
// TODO: avx512 path
// Currently use avx2 code
outputProcess_.template f<inst_set_t::avx2>(
Expand Down
18 changes: 16 additions & 2 deletions src/Fbgemm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ void fbgemmPacked(
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
if ((!fbgemmHasAvx512Support() && !fbgemmHasAvx2Support())) {
if ((!fbgemmHasAvx512VnniSupport() && !fbgemmHasAvx512Support() &&
!fbgemmHasAvx2Support())) {
assert(0 && "unknown architecure");
}

Expand All @@ -62,7 +63,20 @@ void fbgemmPacked(
MR = blocking_params->MR;

} else {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::MCB;
KCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::KCB;
MR = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
inst_set_t::avx512_vnni>::MR;
} else if (fbgemmHasAvx512Support()) {
MCB = PackingTraits<
typename packingAMatrix::inpType,
typename packingAMatrix::accType,
Expand Down
30 changes: 15 additions & 15 deletions src/GenerateKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ namespace fbgemm {
namespace x86 = asmjit::x86;

/**
* @brief AVX2/AVX512 JIT assembly code generator.
* @brief AVX2/AVX512/AVX512VNNI JIT assembly code generator.
* @tparam TA Type of matrix A.
* @tparam TB Type of matrix B.
* @tparam TC Type of matrix C.
Expand Down Expand Up @@ -104,7 +104,7 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void initCRegs(
asmjit::X86Emitter* a,
x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
Expand All @@ -114,10 +114,10 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void genComputeBlock(
asmjit::X86Emitter* a,
asmjit::X86Gp buffer_A,
asmjit::X86Gp buffer_B,
asmjit::X86Gp B_pf,
x86::Emitter* a,
x86::Gp buffer_A,
x86::Gp buffer_B,
x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
Expand All @@ -129,11 +129,11 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
void storeCRegs(
asmjit::X86Emitter* a,
x86::Emitter* a,
int rowRegs,
int colRegs,
asmjit::X86Gp C_Offset,
asmjit::X86Gp ldcReg,
x86::Gp C_Offset,
x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);

Expand Down Expand Up @@ -168,7 +168,9 @@ class CodeGenBase {
fileName += "_MR-" + std::to_string(MR);
fileName += "_NR-" + std::to_string(NR);
fileName += "_NR_MIN-" + std::to_string(NR_MIN);
if (instSet == inst_set_t::avx512) {
if (instSet == inst_set_t::avx512_vnni) {
fileName += "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
fileName += "_avx512";
} else if (instSet == inst_set_t::avx2) {
fileName += "_avx2";
Expand All @@ -178,12 +180,10 @@ class CodeGenBase {
}

private:
asmjit::X86Ymm
CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
asmjit::X86Zmm
x86::Ymm CRegs_avx2_[12]; ///< AVX2 ymm registers for C in the micro-kernel.
x86::Zmm
CRegs_avx512_[28]; ///< AVX512 zmm registers for C in the micro-kernel.
asmjit::X86Zmm
AllRegs_avx512_[32]; ///< all AVX512 zmm registers.
x86::Zmm AllRegs_avx512_[32]; ///< all AVX512 zmm registers.

int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.
Expand Down
Loading

0 comments on commit 7b15607

Please sign in to comment.