Skip to content

Commit

Permalink
Back out "add avx512vnni support for GroupConv"
Browse files Browse the repository at this point in the history
Summary: back out for https://www.internalfb.com/intern/diff/D23421376/

Reviewed By: dskhudia

Differential Revision: D23421914

fbshipit-source-id: 7939b1b7d51d7bfb362a2ab9a8981d3600aca7f9
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Aug 31, 2020
1 parent a288e17 commit 36e0509
Show file tree
Hide file tree
Showing 7 changed files with 77 additions and 148 deletions.
1 change: 1 addition & 0 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "./FbgemmI8DepthwiseAvx2.h"
#include "./FbgemmI8Spmdm.h"
#include "./QuantUtilsAvx2.h"
#include "./QuantUtilsAvx512.h"
#include "./Types.h"
#include "./Utils.h"

Expand Down
26 changes: 0 additions & 26 deletions src/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,6 @@ void genU8I8S32FMA(
a->vpaddd(cReg, tmpReg, cReg);
}

/**
 * @brief AVX512-VNNI overload: emit a single vpdpbusd that accumulates
 * into cReg.
 *
 * This overload takes part in overload resolution only when
 * INST_SET == inst_set_t::avx512_vnni (enable_if on the second template
 * parameter). Per the Intel definition of VPDPBUSD, each 32-bit lane of
 * cReg receives cReg + the sum of four byte products, with the bytes of
 * the first source (aReg) treated as unsigned and those of the second
 * source (bReg) treated as signed.
 *
 * @param a           emitter the instruction is appended to
 * @param aReg        first source operand (unsigned 8-bit side)
 * @param bReg        second source operand (signed 8-bit side)
 * @param cReg        32-bit accumulator; read and written
 * @param oneReg16Bit not referenced by this overload
 * @param tmpReg      not referenced by this overload
 *
 * NOTE(review): oneReg16Bit and tmpReg are presumably kept so that the
 * call shape matches the other instruction-set overloads - confirm.
 */
template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
// Fused multiply-accumulate: no separate widening multiply or add is
// emitted in this overload.
a->vpdpbusd(cReg, aReg, bReg);
}

/**
* @brief Add 4 consecutive numbers of type uint8
* and emit their sum as 32-bit numbers.
Expand Down Expand Up @@ -216,19 +203,6 @@ void genU8Sum4(
/*a->vpaddd(dest, tmpReg, dest);*/
}

/**
 * @brief AVX512-VNNI overload: accumulate into dest the sums of groups of
 * four consecutive bytes of src, using a single vpdpbusd.
 *
 * This overload takes part in overload resolution only when
 * INST_SET == inst_set_t::avx512_vnni (enable_if on the second template
 * parameter).
 *
 * @param a           emitter the instructions are appended to
 * @param src         first source operand of vpdpbusd (unsigned 8-bit side)
 * @param dest        32-bit accumulator; vpdpbusd adds to its current
 *                    contents rather than overwriting them
 * @param oneReg16Bit not referenced by this overload
 * @param tmpReg      scratch register; handed to gen8BitVectorOne and then
 *                    used as the second vpdpbusd source
 *
 * NOTE(review): gen8BitVectorOne is defined outside this view; presumably
 * it sets every byte of tmpReg to 1, which is what makes the dot product
 * below a plain byte sum - confirm.
 */
template <
inst_set_t INST_SET,
typename std::enable_if<INST_SET == inst_set_t::avx512_vnni, int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
// Prepare the multiplier operand in the scratch register.
gen8BitVectorOne(a, tmpReg);
// dest += per-lane dot product of src bytes with tmpReg bytes.
a->vpdpbusd(dest, src, tmpReg);
}

/**
* @brief Add 8 consecutive numbers of type uint8
* and emit their sum as 16-bit numbers.
Expand Down
18 changes: 2 additions & 16 deletions src/GroupwiseConv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
#include "./RefImplementations.h"
#include "./TransposeUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/QuantUtilsAvx512.h"

namespace fbgemm {

Expand Down Expand Up @@ -123,20 +122,7 @@ jit_conv_kernel_fp getOrCreateConvKernel(
accum);

if (cpuinfo_initialize()) {
if (fbgemmHasAvx512VnniSupport()) {
return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>::codeCache_
.getOrCreate(kernelSig, [&]() {
auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512_vnni>(
conv_param,
a_zero_point,
needRowOffset,
isTopEdgeIncluded,
isBottomEdgeIncluded,
isTopBottomEdgeSame,
accum);
return genObj.getOrCreate();
});
} else if (fbgemmHasAvx512Support()) {
if (fbgemmHasAvx512Support()) {
return GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>::codeCache_
.getOrCreate(kernelSig, [&]() {
auto genObj = GenConvKernel<SPATIAL_DIM, inst_set_t::avx512>(
Expand Down Expand Up @@ -905,7 +891,7 @@ void dispatchOutputProcessing(
groups,
outProcess.getActWScale()};
if (cpuinfo_initialize()) {
if (fbgemmHasAvx512Support() || fbgemmHasAvx512VnniSupport()) {
if (fbgemmHasAvx512Support()) {
if (C_per_G == 2) {
if (a_zero_point == 0) {
if (b_symmetric) {
Expand Down
94 changes: 45 additions & 49 deletions src/GroupwiseConv.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,28 +21,6 @@
#include "fbgemm/Utils.h"
/*#define FBGEMM_LOG_CODE 1*/

// Helper macros for declaring/defining GenConvKernel member functions that
// are selected per instruction set through enable_if on an extra template
// parameter ISET. Each macro expands to everything that precedes the member
// name, i.e. the template header(s) plus the (void) return type.

// In-class declaration prefix: member template enabled only for AVX2.
// ISET defaults to the enclosing class's INST_SET.
#define GCONV_INST_AVX2_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// In-class declaration prefix: member template enabled for both AVX512 and
// AVX512-VNNI, which therefore share one declaration.
#define GCONV_INST_AVX512_AND_VNNI_HEADER \
template <inst_set_t ISET = INST_SET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type

// Out-of-class definition prefix matching GCONV_INST_AVX2_HEADER: adds the
// enclosing class template header and omits the default argument for ISET.
#define GCONV_INST_DEF_AVX2_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if<ISET == inst_set_t::avx2, void>::type

// Out-of-class definition prefix matching
// GCONV_INST_AVX512_AND_VNNI_HEADER.
#define GCONV_INST_DEF_AVX512_AND_VNNI_HEADER \
template <int SPATIAL_DIM, inst_set_t INST_SET> \
template <inst_set_t ISET> \
typename std::enable_if< \
ISET == inst_set_t::avx512 || ISET == inst_set_t::avx512_vnni, \
void>::type

namespace fbgemm {

namespace x86 = asmjit::x86;
Expand Down Expand Up @@ -240,49 +218,32 @@ class FBGEMM_API GenConvKernel

jit_conv_kernel_fp getOrCreate();

GCONV_INST_AVX2_HEADER genForLoadingWeights(x86::Emitter* a);
void genForLoadingWeights(x86::Emitter* a) {}

GCONV_INST_AVX512_AND_VNNI_HEADER genForLoadingWeights(x86::Emitter* a);
void genConstForPermutations(x86::Emitter* a) {}

GCONV_INST_AVX2_HEADER genConstForPermutations(x86::Emitter* a);
void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);

GCONV_INST_AVX512_AND_VNNI_HEADER genConstForPermutations(x86::Emitter* a);
void initResultRegs(x86::Emitter* a);

GCONV_INST_AVX2_HEADER genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);
void genCoreInsts(x86::Emitter* a);

GCONV_INST_AVX512_AND_VNNI_HEADER genForSingleFilterPoint(
void genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

GCONV_INST_AVX2_HEADER storeResult(x86::Emitter* a);

GCONV_INST_AVX512_AND_VNNI_HEADER storeResult(x86::Emitter* a);

GCONV_INST_AVX2_HEADER storeOffset(x86::Emitter* a);

GCONV_INST_AVX512_AND_VNNI_HEADER storeOffset(x86::Emitter* a);

void genForTopOrBottomEdge(x86::Emitter* a, bool isTop, bool isBottom);

void initResultRegs(x86::Emitter* a);

void genCoreInsts(x86::Emitter* a);

bool use_zero_reg) {}
void genForSingleOutput(
x86::Emitter* a,
bool isLeft,
bool isRight,
bool isTop,
bool isBottom);

void storeResult(x86::Emitter* a) {}
void storeOffset(x86::Emitter* a) {}

private:
int GTogether_;
// The number of iterations needed for K dim.
Expand Down Expand Up @@ -326,4 +287,39 @@ template <int SPATIAL_DIM, inst_set_t INST_SET>
CodeCache<kernel_sig_t, jit_conv_kernel_fp>
GenConvKernelBase<SPATIAL_DIM, INST_SET>::codeCache_;

// Forward declarations of the ISA-specific explicit specializations of
// GenConvKernel member functions. Only SPATIAL_DIM == 2 is specialized
// here, once for AVX2 and once for AVX512. An explicit specialization has
// to be declared before the first use that would otherwise trigger an
// implicit instantiation, hence these declarations in the header.
// NOTE(review): the definitions presumably live in the per-ISA .cc files -
// confirm that every declaration below has a matching definition.

// --- AVX2 specializations ---
template <>
void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations(
x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

// --- AVX512 specializations (same set of members as for AVX2) ---
template <>
void GenConvKernel<2, inst_set_t::avx512>::genConstForPermutations(
x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::genForLoadingWeights(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::storeResult(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::storeOffset(x86::Emitter* a);
template <>
void GenConvKernel<2, inst_set_t::avx512>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
int act_s,
bool use_zero_reg);

} // namespace fbgemm
36 changes: 13 additions & 23 deletions src/GroupwiseConvAcc32Avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@ using namespace std;

namespace x86 = asmjit::x86;

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPermutations(
template <>
void GenConvKernel<2, inst_set_t::avx2>::genConstForPermutations(
x86::Emitter* a) {
if (this->C_per_G_ == 4) {
x86::Gp permute_const_reg = a->gpz(12);
Expand Down Expand Up @@ -44,8 +45,8 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genConstForPerm
}
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWeights(
x86::Emitter* a) {
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForLoadingWeights(x86::Emitter* a) {
using WRegs = x86::Ymm;
int paddedICPerG = (this->C_per_G_ + 3) / 4 * 4;
// load weights
Expand All @@ -65,7 +66,8 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForLoadingWe
}
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(x86::Emitter* a) {
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeResult(x86::Emitter* a) {
if (GTogether_ > 1) {
// store with permutation
a->vpermd(x86::Ymm(9), stPermReg_V_, x86::Ymm(9));
Expand Down Expand Up @@ -111,7 +113,8 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeResult(x86
}
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(x86::Emitter* a) {
template <>
void GenConvKernel<2, inst_set_t::avx2>::storeOffset(x86::Emitter* a) {
switch (this->C_per_G_) {
case 2:
// store 128-bits containing rowoffset for four groups
Expand Down Expand Up @@ -152,7 +155,8 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::storeOffset(x86
}
}

GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFilterPoint(
template <>
void GenConvKernel<2, inst_set_t::avx2>::genForSingleFilterPoint(
x86::Emitter* a,
int r,
int s,
Expand Down Expand Up @@ -185,12 +189,12 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFil
}
// row offset
if (this->needRowOffset_) {
genU8Sum4<INST_SET>(
genU8Sum4<inst_set_t::avx2>(
a, actReg_V_, rowOffsetReg_V_, oneReg16Bit_V_, tmpReg1_V_);
}
// 32 * int8 weight product 32 * uint8 activation -> 8
// output(K_per_g * group_together)
genU8I8S32FMA<INST_SET>(
genU8I8S32FMA<inst_set_t::avx2>(
a,
actReg_V_,
WRegs(r * this->S_ + s),
Expand Down Expand Up @@ -231,26 +235,12 @@ GCONV_INST_DEF_AVX2_HEADER GenConvKernel<SPATIAL_DIM, INST_SET>::genForSingleFil
// FMA result is not the final reduction over C_per_G: it produces 8 outputs
// in which every 2 consecutive elements, once summed, form one final output
// over the K_Per_G dimension
genU8I8S32FMA<INST_SET>(
genU8I8S32FMA<inst_set_t::avx2>(
a, actReg_V_, WRegs(0), x86::Ymm(9 - k), oneReg16Bit_V_, tmpReg1_V_);
}
}
}

// Explicitly instantiate the ISA-dispatched member function templates of
// GenConvKernel<S, IN> with ISET == IN, so that their definitions in this
// translation unit are emitted. S is the spatial dimension and IN the
// instruction set. The macro is local to this file and is undefined right
// after use.
#define GENCONVKERNEL_FUNCS(S, IN) \
template void GenConvKernel<S, IN>::genForLoadingWeights<IN>( \
x86::Emitter * a); \
template void GenConvKernel<S, IN>::genConstForPermutations<IN>( \
x86::Emitter * a); \
template void GenConvKernel<S, IN>::genForSingleFilterPoint<IN>( \
x86::Emitter * a, int r, int s, int act_s, bool use_zero_reg); \
template void GenConvKernel<S, IN>::storeResult<IN>(x86::Emitter * a); \
template void GenConvKernel<S, IN>::storeOffset<IN>(x86::Emitter * a);
// AVX2 instantiations for 1-D, 2-D and 3-D spatial dimensions.
GENCONVKERNEL_FUNCS(1, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(2, inst_set_t::avx2)
GENCONVKERNEL_FUNCS(3, inst_set_t::avx2)
#undef GENCONVKERNEL_FUNCS

// Explicit instantiation of the GenConvKernel class template for AVX2 with
// spatial dimensions 1, 2 and 3.
template class GenConvKernel<1, inst_set_t::avx2>;
template class GenConvKernel<2, inst_set_t::avx2>;
template class GenConvKernel<3, inst_set_t::avx2>;
Expand Down
Loading

0 comments on commit 36e0509

Please sign in to comment.