Skip to content

Commit

Permalink
Back out "add avx512 support for groupConv"
Browse files Browse the repository at this point in the history
Summary: revert D23003620 (pytorch@2b81eef) for gcc9 build issue

Reviewed By: dskhudia

Differential Revision: D23243218

fbshipit-source-id: 3a116c85733c7254d677435bd41d50fbf4fe7e73
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Aug 20, 2020
1 parent 2b81eef commit 156bc80
Show file tree
Hide file tree
Showing 11 changed files with 355 additions and 1,486 deletions.
3 changes: 0 additions & 3 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@ def get_fbgemm_generic_srcs(with_base = False):
"src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc",
"src/GroupwiseConv.cc",
"src/GroupwiseConvAcc32Avx2.cc",
"src/GroupwiseConvAcc32Avx512.cc",
"src/PackAMatrix.cc",
"src/PackAWithIm2Col.cc",
"src/PackAWithQuantRowOffset.cc",
Expand Down Expand Up @@ -59,7 +58,6 @@ def get_fbgemm_public_headers():
"include/fbgemm/PackingTraits-inl.h",
"include/fbgemm/QuantUtils.h",
"include/fbgemm/QuantUtilsAvx2.h",
"include/fbgemm/QuantUtilsAvx512.h",
"include/fbgemm/Utils.h",
"include/fbgemm/UtilsAvx2.h",
"include/fbgemm/ConvUtils.h",
Expand Down Expand Up @@ -89,7 +87,6 @@ def get_fbgemm_avx512_srcs(msvc = False):
#All the source files that use avx512 instructions statically
"src/FbgemmBfloat16ConvertAvx512.cc",
"src/FbgemmFloat16ConvertAvx512.cc",
"src/QuantUtilsAvx512.cc",
"src/UtilsAvx512.cc",
#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
"src/FbgemmFP16UKernelsAvx512.cc" if not msvc else "src/FbgemmFP16UKernelsIntrinsicAvx512.cc",
Expand Down
1 change: 0 additions & 1 deletion include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include "./FbgemmI8DepthwiseAvx2.h"
#include "./FbgemmI8Spmdm.h"
#include "./QuantUtilsAvx2.h"
#include "./QuantUtilsAvx512.h"
#include "./Types.h"
#include "./Utils.h"

Expand Down
23 changes: 0 additions & 23 deletions include/fbgemm/QuantUtilsAvx512.h

This file was deleted.

121 changes: 48 additions & 73 deletions src/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,24 @@ namespace x86 = asmjit::x86;
* dest[0:15] will have 0x0001, dest[16:31]
* will have 0x0001 and so on
*/
template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
void gen16BitVectorOne(x86::Emitter* a, T dest)
{
a->vpcmpeqw(dest, dest, dest);
a->vpsrlw(dest, dest, 15);
}

template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
Expand All @@ -49,24 +51,27 @@ void gen16BitVectorOne(x86::Emitter* a, T dest) {
*
* @param dest Destination vector register
*/
template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
void emitLoadDWord(
x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
}

template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
void emitLoadDWord(
x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
}

/**
Expand All @@ -78,47 +83,42 @@ void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
* @param vec Source (full) vector register
* @param idx Index of of the half vector 0 or 1
*/
template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Ymm half,
const x86::Zmm vec,
int idx) {
x86::Emitter* a, x86::Ymm half, const x86::Zmm vec, int idx) {
a->vextracti32x8(half, vec, idx);
}

template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
x86::Emitter* a, x86::Xmm half, x86::Ymm vec, int idx) {
a->vextracti32x4(half, vec, idx);
}

template <
template<
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
typename std::enable_if<
instSet == inst_set_t::avx2,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
x86::Emitter* a, x86::Xmm half, x86::Ymm vec, int idx) {
a->vextracti128(half, vec, idx);
}

Expand All @@ -130,42 +130,27 @@ void emitExtractHalfVector(
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <
typename T,
typename std::enable_if<std::is_same<T, x86::Ymm>::value, int>::type = 0>
template <typename T>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpabsb(dest, dest);
}

template <
typename T,
typename std::enable_if<std::is_same<T, x86::Zmm>::value, int>::type = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
a->vpabsb(dest, dest);
}

/**
* @brief Generates instruction sequence to compute s32 += U8 * I8
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param cReg contains result
*
*/

template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
template <typename T>
void genU8I8S32FMA(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
T aReg,
T bReg,
T cReg,
T oneReg16Bit,
T tmpReg) {
a->vpmaddubsw(tmpReg, aReg, bReg);
a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
a->vpaddd(cReg, tmpReg, cReg);
Expand All @@ -181,17 +166,8 @@ void genU8I8S32FMA(
* @param dest contains result
*
*/
template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
template <typename T>
void genU8Sum4(x86::Emitter* a, T src, T dest, T oneReg16Bit, T tmpReg) {
gen8BitVectorOne(a, tmpReg);
a->vpmaddubsw(tmpReg, src, tmpReg);
a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
Expand Down Expand Up @@ -237,9 +213,8 @@ void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) {
template <typename T>
void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
// move src to dest
auto xmm = dest.xmm();
a->movq(xmm, src);
a->vpbroadcastb(dest, xmm);
a->movq(dest.half(), src);
a->vpbroadcastb(dest, dest.half());
}

} // namespace fbgemm
Loading

0 comments on commit 156bc80

Please sign in to comment.