Skip to content

Commit

Permalink
add avx512vnni support for GroupConv (pytorch#411)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#411

this diff adds support for avx512vnni in Groupwise Convolution and also refactor the code structure that we can extend for spatial-specific and instruction-specifc subroutines

Reviewed By: dskhudia

Differential Revision: D23120201

fbshipit-source-id: 4e3d7365f84edce6c0dbf2e9dc1716e448bbafdb
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Aug 30, 2020
1 parent c2d2c2c commit a288e17
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 77 deletions.
1 change: 0 additions & 1 deletion include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "./FbgemmI8DepthwiseAvx2.h"
#include "./FbgemmI8Spmdm.h"
#include "./QuantUtilsAvx2.h"
#include "./QuantUtilsAvx512.h"
#include "./Types.h"
#include "./Utils.h"

Expand Down
26 changes: 26 additions & 0 deletions src/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,19 @@ void genU8I8S32FMA(
a->vpaddd(cReg, tmpReg, cReg);
}

template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
a->vpdpbusd(cReg, aReg, bReg);
}

/**
* @brief Add 4 consecutive numbers of type uint8
* and emit their sum as 32-bit numbers.
Expand Down Expand Up @@ -203,6 +216,19 @@ void genU8Sum4(
/*a->vpaddd(dest, tmpReg, dest);*/
}

template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
a->vpdpbusd(dest, src, tmpReg);
}

/**
* @brief Add 8 consecutive numbers of type uint8
* and emit their sum as 16-bit numbers.
Expand Down
18 changes: 16 additions & 2 deletions src/GroupwiseConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "./RefImplementations.h"
#include "./TransposeUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/QuantUtilsAvx512.h"

namespace fbgemm {

Expand Down Expand Up @@ -122,7 +123,20 @@ jit_conv_kernel_fp getOrCreateConvKernel(
accum);

if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512VnniSupport()) {
return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>::codeCache_
.getOrCreate(kernelSig, [&]() {
auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>(
conv_param,
a_zero_point,
needRowOffset,
isTopEdgeIncluded,
isBottomEdgeIncluded,
isTopBottomEdgeSame,
accum);
return genObj.getOrCreate();
});
} else if (fbgemmHasAvx512Support()) {
return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>::codeCache_
.getOrCreate(kernelSig, [&]() {
auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>(
Expand Down Expand Up @@ -891,7 +905,7 @@ void dispatchOutputProcessing(
groups,
outProcess.getActWScale()};
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
if (C_per_G == 2) {
if (a_zero_point == 0) {
if (b_symmetric) {
Expand Down
94 changes: 49 additions & 45 deletions src/GroupwiseConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,28 @@
#include "fbgemm/Utils.h"
/*#define FBGEMM_LOG_CODE 1*/

#define GCONV_INST_AVX2_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type

#define GCONV_INST_AVX512_AND_VNNI_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type

#define GCONV_INST_DEF_AVX2_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type

#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type

namespace fbgemm {

namespace x86 = asmjit::x86;
Expand Down Expand Up @@ -218,32 +240,49 @@ class FBGEMM_API GenConvKernel

jit_conv_kernel_fp getOrCreate();

void genForLoadingWeights(x86::Emitter* a) {}
GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a);

void genConstForPermutations(x86::Emitter* a) {}
GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a);

void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);
GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a);

void initResultRegs(x86::Emitter* a);
GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a);

void genCoreInsts(x86::Emitter* a);
GCONV_INST_AVX2_HEADER genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

void genForSingleFilterPoint(
GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg) {}
bool use_zero_reg);

GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a);

GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a);

GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a);

GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a);

void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);

void initResultRegs(x86::Emitter* a);

void genCoreInsts(x86::Emitter* a);

void genForSingleOutput(
x86::Emitter* a,
bool isLeft,
bool isRight,
bool isTop,
bool isBottom);

void storeResult(x86::Emitter* a) {}
void storeOffset(x86::Emitter* a) {}

private:
int GTogether_;
// The number of iterations needed for K dim.
Expand Down Expand Up @@ -287,39 +326,4 @@ template <int SPATIAL_DIM, inst_set_t INST_SET>
CodeCache<kernel_sig_t, jit_conv_kernel_fp>
GenConvKernelBase<SPATIAL_DIM, INST_SET>::codeCache_;

// forward declaration of specialized ISA specific functions
template <>
void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations(
x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

template <>
void GenConvKernel<2, inst_set_t::avx512>::genConstForPermutations(
x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::genForLoadingWeights(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::storeResult(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::storeOffset(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

} // namespace fbgemm
36 changes: 23 additions & 13 deletions src/GroupwiseConvAcc32Avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@ using namespace std;

namespace x86 = asmjit::x86;

template <>
void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations(
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(
x86::Emitter* a) {
if (this->C_per_G_ == 4) {
x86::Gp permute_const_reg = a->gpz(12);
Expand Down Expand Up @@ -45,8 +44,8 @@ void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations(
}
}

template <>
void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a) {
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(
x86::Emitter* a) {
using WRegs = x86::Ymm;
int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4;
// load weights
Expand All @@ -66,8 +65,7 @@ void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a) {
}
}

template <>
void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a) {
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(x86::Emitter* a) {
if (GTogether_ > 1) {
// store with permutation
a->vpermd(x86::Ymm(9), stPermReg_V_, x86::Ymm(9));
Expand Down Expand Up @@ -113,8 +111,7 @@ void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a) {
}
}

template <>
void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a) {
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(x86::Emitter* a) {
switch (this->C_per_G_) {
case 2:
// store 128-bits containing rowoffset for four groups
Expand Down Expand Up @@ -155,8 +152,7 @@ void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a) {
}
}

template <>
void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
Expand Down Expand Up @@ -189,12 +185,12 @@ void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
}
// row offset
if (this->needRowOffset_) {
genU8Sum4<inst_set_t::avx2>(
genU8Sum4<INST_SET>(
a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_);
}
// 32 * int8 weight product 32 * uint8 activation -> 8
// output(K_per_g * group_together)
genU8I8S32FMA<inst_set_t::avx2>(
genU8I8S32FMA<INST_SET>(
a,
actReg_V_,
WRegs(r * this->S_ + s),
Expand Down Expand Up @@ -235,12 +231,26 @@ void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
// FMA result is not final reduction on C_per_G, producing 8 output in
// which consectutive 2 elements if summedforms one final output over
// K_Per_G dimension
genU8I8S32FMA<inst_set_t::avx2>(
genU8I8S32FMA<INST_SET>(
a, actReg_V_, WRegs(0), x86::Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_);
}
}
}

#define GENCONVKERNEL_FUNCS(S, IN) \
template void GenConvKernel<S, IN>::genForLoadingWeights<IN>( \
x86::Emitter * a); \
template void GenConvKernel<S, IN>::genConstForPermutations<IN>( \
x86::Emitter * a); \
template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>( \
x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg); \
template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a); \
template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a);
GENCONVKERNEL_FUNCS(1, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx2)
#undef GENCONVKERNEL_FUNCS

template class GenConvKernel<1, inst_set_t::avx2>;
template class GenConvKernel<2, inst_set_t::avx2>;
template class GenConvKernel<3, inst_set_t::avx2>;
Expand Down
Loading

0 comments on commit a288e17

Please sign in to comment.