Skip to content

Commit

Permalink
add avx512 support for groupConv (pytorch#422)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#422

Add the avx512 features for groupconv due to multiple back-and-forths of hhvm compiler avx512 dependency removal.

Reviewed By: dskhudia

Differential Revision: D23436486

fbshipit-source-id: 7ef1c206e067038ccf1afcc8df6db388ff889070
  • Loading branch information
YazhiGao authored and facebook-github-bot committed Sep 4, 2020
1 parent a591925 commit d5ace7c
Show file tree
Hide file tree
Showing 11 changed files with 1,492 additions and 355 deletions.
3 changes: 3 additions & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def get_fbgemm_generic_srcs(with_base = False):
"src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc",
"src/GroupwiseConv.cc",
"src/GroupwiseConvAcc32Avx2.cc",
"src/GroupwiseConvAcc32Avx512.cc",
"src/PackAMatrix.cc",
"src/PackAWithIm2Col.cc",
"src/PackAWithQuantRowOffset.cc",
Expand Down Expand Up @@ -58,6 +59,7 @@ def get_fbgemm_public_headers():
"include/fbgemm/PackingTraits-inl.h",
"include/fbgemm/QuantUtils.h",
"include/fbgemm/QuantUtilsAvx2.h",
"include/fbgemm/QuantUtilsAvx512.h",
"include/fbgemm/Utils.h",
"include/fbgemm/UtilsAvx2.h",
"include/fbgemm/ConvUtils.h",
Expand Down Expand Up @@ -87,6 +89,7 @@ def get_fbgemm_avx512_srcs(msvc = False):
#All the source files that use avx512 instructions statically
"src/FbgemmBfloat16ConvertAvx512.cc",
"src/FbgemmFloat16ConvertAvx512.cc",
"src/QuantUtilsAvx512.cc",
"src/UtilsAvx512.cc",
#FP16 kernels contain inline assembly and inline assembly syntax for MSVC is different.
"src/FbgemmFP16UKernelsAvx512.cc" if not msvc else "src/FbgemmFP16UKernelsIntrinsicAvx512.cc",
Expand Down
1 change: 1 addition & 0 deletions include/fbgemm/Fbgemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "./FbgemmI8DepthwiseAvx2.h"
#include "./FbgemmI8Spmdm.h"
#include "./QuantUtilsAvx2.h"
#include "./QuantUtilsAvx512.h"
#include "./Types.h"
#include "./Utils.h"

Expand Down
29 changes: 29 additions & 0 deletions include/fbgemm/QuantUtilsAvx512.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include <cstdint>
#include "./FbgemmBuild.h"
#include "./UtilsAvx2.h"

namespace fbgemm {
/**
* @brief Declaration of the AVX512 output-requantization entry point for
* group convolution. Only the signature lives in this header; the
* definition is expected in a separate translation unit
* (presumably src/QuantUtilsAvx512.cc -- TODO confirm).
*
* NOTE(review): the descriptions below are inferred from the parameter
* names and types only; confirm against the definition before relying
* on them.
*
* @tparam A_SYMMETRIC compile-time flag; presumably whether activation
* quantization is symmetric
* @tparam B_SYMMETRIC compile-time flag; presumably whether weight
* quantization is symmetric
* @tparam Q_GRAN quantization granularity selector
* @tparam HAS_BIAS compile-time flag; presumably whether a bias term is
* applied
* @tparam FUSE_RELU compile-time flag; presumably whether ReLU is fused
* into the requantization
* @tparam C_PER_G compile-time integer; presumably channels per group
* @tparam BIAS_TYPE element type used by requantizationParams_t;
* defaults to std::int32_t
*
* @param out destination buffer of 8-bit unsigned values
* @param inp source buffer of 32-bit signed values (read-only)
* @param block block descriptor (block_type_t is declared elsewhere,
* presumably in UtilsAvx2.h)
* @param ld_out presumably the leading dimension (row stride) of out
* @param ld_in presumably the leading dimension (row stride) of inp
* @param r requantization parameters
*/
template <
bool A_SYMMETRIC,
bool B_SYMMETRIC,
QuantizationGranularity Q_GRAN,
bool HAS_BIAS,
bool FUSE_RELU,
int C_PER_G,
typename BIAS_TYPE = std::int32_t>
FBGEMM_API void requantizeOutputProcessingGConvAvx512(
std::uint8_t* out,
const std::int32_t* inp,
const block_type_t& block,
int ld_out,
int ld_in,
const requantizationParams_t<BIAS_TYPE>& r);
} // namespace fbgemm
121 changes: 73 additions & 48 deletions src/CodeGenHelpers.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,24 +20,22 @@ namespace x86 = asmjit::x86;
* dest[0:15] will have 0x0001, dest[16:31]
* will have 0x0001 and so on
*/
template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest)
{
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpcmpeqw(dest, dest, dest);
a->vpsrlw(dest, dest, 15);
}

template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
a->vpternlogd(dest, dest, dest, 0xff);
Expand All @@ -51,27 +49,24 @@ void gen16BitVectorOne(x86::Emitter* a, T dest) {
*
* @param dest Destination vector register
*/
template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void emitLoadDWord(
x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa(dest, ptr);
}

template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitLoadDWord(
x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
void emitLoadDWord(x86::Emitter* a, T dest, const x86::Mem& ptr) {
a->vmovdqa32(dest, ptr);
}

/**
Expand All @@ -83,42 +78,47 @@ void emitLoadDWord(
* @param vec Source (full) vector register
* @param idx Index of the half vector: 0 or 1
*/
template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a, x86::Ymm half, const x86::Zmm vec, int idx) {
x86::Emitter* a,
x86::Ymm half,
const x86::Zmm vec,
int idx) {
a->vextracti32x8(half, vec, idx);
}

template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx512 ||
instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
instSet == inst_set_t::avx512 || instSet == inst_set_t::avx512_ymm ||
instSet == inst_set_t::avx512_vnni ||
instSet == inst_set_t::avx512_vnni_ymm,
int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a, x86::Xmm half, x86::Ymm vec, int idx) {
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
a->vextracti32x4(half, vec, idx);
}

template<
template <
inst_set_t instSet,
typename T,
typename std::enable_if<
instSet == inst_set_t::avx2,
int>::type = 0>
typename std::enable_if<instSet == inst_set_t::avx2, int>::type = 0>
void emitExtractHalfVector(
x86::Emitter* a, x86::Xmm half, x86::Ymm vec, int idx) {
x86::Emitter* a,
x86::Xmm half,
x86::Ymm vec,
int idx) {
a->vextracti128(half, vec, idx);
}

Expand All @@ -130,27 +130,42 @@ void emitExtractHalfVector(
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <typename T>
// Overload that participates in overload resolution only when T is exactly
// x86::Ymm (enforced by the enable_if/is_same condition below).
template <
typename T,
typename std::enable_if<std::is_same<T, x86::Ymm>::value, int>::type = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
// Comparing dest with itself for equality sets every bit of dest to 1,
// so each byte becomes 0xFF (i.e. -1 as a signed byte).
a->vpcmpeqw(dest, dest, dest);
// Byte-wise absolute value turns each 0xFF (-1) byte into 0x01.
a->vpabsb(dest, dest);
}

// Overload that participates in overload resolution only when T is exactly
// x86::Zmm (enforced by the enable_if/is_same condition below).
template <
typename T,
typename std::enable_if<std::is_same<T, x86::Zmm>::value, int>::type = 0>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
// Ternary-logic with truth table 0xff yields 1 for every input
// combination, so every bit of dest is set (each byte becomes 0xFF).
a->vpternlogd(dest, dest, dest, 0xff);
// Byte-wise absolute value turns each 0xFF (-1) byte into 0x01.
a->vpabsb(dest, dest);
}

/**
* @brief Generates instruction sequence to compute s32 += U8 * I8
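* @tparam INST_SET Instruction set (avx2 or avx512); it selects the register
* type of all vector operands via simd_info<INST_SET>::vec_reg_t,
* e.g., x86::Ymm or x86::Zmm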
*
* @param cReg contains result
*
*/
template <typename T>

template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
void genU8I8S32FMA(
x86::Emitter* a,
T aReg,
T bReg,
T cReg,
T oneReg16Bit,
T tmpReg) {
typename simd_info<INST_SET>::vec_reg_t aReg,
typename simd_info<INST_SET>::vec_reg_t bReg,
typename simd_info<INST_SET>::vec_reg_t cReg,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
a->vpmaddubsw(tmpReg, aReg, bReg);
a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
a->vpaddd(cReg, tmpReg, cReg);
Expand All @@ -166,8 +181,17 @@ void genU8I8S32FMA(
* @param dest contains result
*
*/
template <typename T>
void genU8Sum4(x86::Emitter* a, T src, T dest, T oneReg16Bit, T tmpReg) {
template <
inst_set_t INST_SET,
typename std::enable_if<
INST_SET == inst_set_t::avx2 || INST_SET == inst_set_t::avx512,
int>::type = 0>
void genU8Sum4(
x86::Emitter* a,
typename simd_info<INST_SET>::vec_reg_t src,
typename simd_info<INST_SET>::vec_reg_t dest,
typename simd_info<INST_SET>::vec_reg_t oneReg16Bit,
typename simd_info<INST_SET>::vec_reg_t tmpReg) {
gen8BitVectorOne(a, tmpReg);
a->vpmaddubsw(tmpReg, src, tmpReg);
a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
Expand Down Expand Up @@ -213,8 +237,9 @@ void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) {
template <typename T>
void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
// move src to dest
a->movq(dest.half(), src);
a->vpbroadcastb(dest, dest.half());
auto xmm = dest.xmm();
a->movq(xmm, src);
a->vpbroadcastb(dest, xmm);
}

} // namespace fbgemm
Loading

0 comments on commit d5ace7c

Please sign in to comment.