Skip to content

Commit

Permalink
Add support for AVX512-256(YMM) in FBGEMM16 (pytorch#209)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#209

Add AVX512-256 support to FBGEMM operation. This benefits Intel(r) Xeon(r) D processors by running at higher turbo frequency.

Reviewed By: jianyuh

Differential Revision: D18138146

fbshipit-source-id: 7f25247b92e62a058797b2a44ba57b147cb7f5f6
  • Loading branch information
efiks authored and facebook-github-bot committed Dec 11, 2019
1 parent 3839cba commit 6394aab
Show file tree
Hide file tree
Showing 14 changed files with 3,431 additions and 1,716 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ set(FBGEMM_AVX512_SRCS
src/FbgemmBfloat16ConvertAvx512.cc
src/FbgemmFP16UKernelsAvx512.cc
src/FbgemmFloat16ConvertAvx512.cc
src/FbgemmFP16UKernelsAvx512_256.cc
src/UtilsAvx512.cc)

set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
Expand Down
2 changes: 2 additions & 0 deletions bench/FP16Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,8 @@ int main(int argc, const char* argv[]) {
int repetitions = parseArgumentInt(argc, argv, "--repit=", 1, 1);
bool no_flush = parseArgumentBool(argc, argv, "--no-flush", false);
bool no_mkl = parseArgumentBool(argc, argv, "--no-mkl", false);
bool enableAvx512_ymm = parseArgumentBool(argc, argv, "--avx512-256", false);
fbgemmEnableAvx512Ymm(enableAvx512_ymm);

performance_test(num_instances, !no_flush, repetitions, !no_mkl);
}
2 changes: 1 addition & 1 deletion include/fbgemm/FbgemmFP16.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class PackedGemmMatrixFP16 {
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
bcol_ = (fbgemmHasAvx512Support()
bcol_ = (isZmm(fbgemmInstructionSet())
? simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS
: simd_info<inst_set_t::avx2>::WIDTH_32BIT_ELEMS) *
kernelNumColBlocks();
Expand Down
45 changes: 38 additions & 7 deletions include/fbgemm/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ enum class matrix_op_t { NoTranspose, Transpose };
/**
* @brief Typed enum for supported instruction sets.
*/
enum class inst_set_t { anyarch, avx2, avx512, avx512_vnni };
enum class inst_set_t { anyarch, avx2, avx512, avx512_ymm, avx512_vnni };

/**
* @brief Typed enum for optimized paths for convolutions
Expand Down Expand Up @@ -130,6 +130,21 @@ FBGEMM_API void transpose_simd(
float* dst,
int ld_dst);

/**
* @brief Explicitly set instruction set to be used
*/
FBGEMM_API void fbgemmForceIsa(inst_set_t);

/**
* @brief Enable AVX512-256 path for Intel(r) Xeon(r) D servers
*/
void fbgemmEnableAvx512Ymm(bool);

/**
* @brief Are we running on a Xeon-D cpu?
*/
FBGEMM_API bool fbgemmIsIntelXeonD();

/**
* @brief Are we running on a AVX512 supported cpu?
*/
Expand All @@ -145,6 +160,21 @@ FBGEMM_API bool fbgemmHasAvx2Support();
*/
FBGEMM_API bool fbgemmHasAvx512VnniSupport();

/**
* @brief Retrieve current CPU instruction set
*/
FBGEMM_API inst_set_t fbgemmInstructionSet();

/**
* @brief Is ISA is wide vector ZMM
*/
FBGEMM_API bool isZmm(inst_set_t);

/**
* @brief Is ISA is wide vector ZMM
*/
FBGEMM_API bool isYmm(inst_set_t);

/**
* @brief Helper struct to enable autotuning of FBGEMM packing and kernels.
*
Expand Down Expand Up @@ -228,26 +258,27 @@ template <typename accT = std::int32_t>
FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
constexpr bool is_32bit = std::is_same<accT, int32_t>::value;
constexpr bool is_16bit = std::is_same<accT, int16_t>::value;
static const auto iset = fbgemmInstructionSet();

if (is_32bit) {
if (param->ROW_INTERLEAVE != 4)
return false;

if (fbgemmHasAvx512Support()) {
if (isZmm(iset)) {
if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
} else if (isYmm(iset)) {
if (param->NR_MIN != 8 || param->NR % param->NR_MIN)
return false;
}
} else if (is_16bit) {
if (param->ROW_INTERLEAVE != 2)
return false;

if (fbgemmHasAvx512Support()) {
if (isZmm(iset)) {
if (param->NR_MIN != 32 || param->NR % param->NR_MIN)
return false;
} else if (fbgemmHasAvx2Support()) {
} else if (isYmm(iset)) {
if (param->NR_MIN != 16 || param->NR % param->NR_MIN)
return false;
}
Expand All @@ -257,7 +288,7 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
if (param->NCB % param->NR)
return false;
if (fbgemmHasAvx512Support()) {
if (isZmm(iset)) {
if (is_32bit) {
// Zmm register usage for C
if (param->MR * (param->NR / param->NR_MIN) > 28)
Expand All @@ -269,7 +300,7 @@ FBGEMM_API bool isValidBlockingFactor(BlockingFactors* param) {
return false;
}

} else if (fbgemmHasAvx2Support()) {
} else if (isYmm(iset)) {
if (param->MR * (param->NR / param->NR_MIN) > 12)
return false;
}
Expand Down
118 changes: 67 additions & 51 deletions src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,17 @@
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "fbgemm/FbgemmFP16.h"

#include "fbgemm/Fbgemm.h"

#include <cpuinfo.h>
#include <array>
#include <cmath>
#include <utility>

#include "./FbgemmFP16Common.h"
#include "./FbgemmFP16UKernelsAvx2.h"
#include "./FbgemmFP16UKernelsAvx512.h"
#include "./FbgemmFP16UKernelsAvx512_256.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/FbgemmFP16.h"

using namespace std;

Expand Down Expand Up @@ -44,40 +44,59 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
// }
// }

struct KernelInfo {
using knl_ptr = funcptr_fp16;
namespace KernelInfo {
using knl_ptr = funcptr_t<float16>;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
// the restrictions of ymm register numbers (16).
static constexpr knl_ptr kernel_avx2[] = {nullptr,
gemmkernel_1x2_Avx2_fA0fB0fC0,
gemmkernel_2x2_Avx2_fA0fB0fC0,
gemmkernel_3x2_Avx2_fA0fB0fC0,
gemmkernel_4x2_Avx2_fA0fB0fC0,
gemmkernel_5x2_Avx2_fA0fB0fC0,
gemmkernel_6x2_Avx2_fA0fB0fC0};
constexpr std::array<knl_ptr, 15> kernel_avx2 = {
nullptr,
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_4x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_5x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_6x2_Avx2_fp16_fA0fB0fC0};

constexpr std::array<knl_ptr, 15> kernel_avx512_256 = {
nullptr,
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_3x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_4x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_5x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_6x2_Avx2_fp16_fA0fB0fC0,
gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0,
gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0};

static constexpr knl_ptr kernel_avx512[] = {nullptr,
gemmkernel_1x2_Avx512_fA0fB0fC0,
gemmkernel_2x2_Avx512_fA0fB0fC0,
gemmkernel_3x2_Avx512_fA0fB0fC0,
gemmkernel_4x2_Avx512_fA0fB0fC0,
gemmkernel_5x2_Avx512_fA0fB0fC0,
gemmkernel_6x2_Avx512_fA0fB0fC0,
gemmkernel_7x2_Avx512_fA0fB0fC0,
gemmkernel_8x2_Avx512_fA0fB0fC0,
gemmkernel_9x2_Avx512_fA0fB0fC0,
gemmkernel_10x2_Avx512_fA0fB0fC0,
gemmkernel_11x2_Avx512_fA0fB0fC0,
gemmkernel_12x2_Avx512_fA0fB0fC0,
gemmkernel_13x2_Avx512_fA0fB0fC0,
gemmkernel_14x2_Avx512_fA0fB0fC0};
constexpr std::array<knl_ptr, 15> kernel_avx512 = {
nullptr,
gemmkernel_1x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_2x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_3x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_4x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_5x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_6x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_7x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_8x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_9x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_10x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_11x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_12x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_13x2_Avx512_fp16_fA0fB0fC0,
gemmkernel_14x2_Avx512_fp16_fA0fB0fC0};

// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
// clang-format off
static constexpr array<array<array<int, 2>, 2>, 121> partition_avx2 = {
// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
// clang-format off
constexpr partition_array_t partition_avx2 = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
Expand Down Expand Up @@ -204,7 +223,7 @@ static constexpr knl_ptr kernel_avx512[] = {nullptr,
{{ { 6, 20 }, { 0, 0 } } }, // 120
}
};
static constexpr array<array<array<int, 2>, 2>, 121> partition_avx512 = {
constexpr partition_array_t partition_avx512 = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
Expand Down Expand Up @@ -331,12 +350,8 @@ static constexpr knl_ptr kernel_avx512[] = {nullptr,
{{ { 14, 8 }, { 8, 1 } } }, // 120
}
};
// clang-format on
};
constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx2[];
constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx512[];
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx2;
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx512;
// clang-format on
}; // namespace KernelInfo

// define this to debug fp16 kernel using a reference C implementation
// #define FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
Expand Down Expand Up @@ -399,7 +414,9 @@ FBGEMM_API void cblas_gemm_compute(
const int mb_max = 120;
// By some reason, if packed B is using packing layout for avx2, we just use
// avx2 even if avx512 is available.
bool use_avx512 = fbgemmHasAvx512Support() &&
static inst_set_t isa = fbgemmInstructionSet();

bool use_avx512 = isZmm(isa) &&
(Bp.blockColSize() ==
simd_info<inst_set_t::avx512>::WIDTH_32BIT_ELEMS *
Bp.kernelNumColBlocks());
Expand All @@ -408,12 +425,15 @@ FBGEMM_API void cblas_gemm_compute(
static thread_local unique_ptr<std::array<float, 256 * 1024>> scratchpad(
new std::array<float, 256 * 1024>());

GemmParams gp;
GemmParams<float16> gp;

const funcptr_fp16* kernels =
use_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2;
const array<array<array<int, 2>, 2>, 121>& partition =
use_avx512 ? KernelInfo::partition_avx512 : KernelInfo::partition_avx2;
const auto& kernels = use_avx512
? KernelInfo::kernel_avx512
: isa == inst_set_t::avx512_ymm ? KernelInfo::kernel_avx512_256
: KernelInfo::kernel_avx2;
const auto& partition = use_avx512 || isa == inst_set_t::avx512_ymm
? KernelInfo::partition_avx512
: KernelInfo::partition_avx2;

int i_begin, i_end;
// fbgemmPartition1D(thread_id, num_threads, m, i_begin, i_end);
Expand All @@ -425,16 +445,13 @@ FBGEMM_API void cblas_gemm_compute(
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid "Nan" problem
float beta_;
uint64_t accum;
if (k_ind == 0) {
// accumulate of beta != 0.0
// do not!!! accumulate otherwise
beta_ = beta;
accum = (beta_ == 0.0f) ? 0 : 1;
} else {
// always accumulate with beta_ = 1.0f
beta_ = 1.0f;
accum = 1;
}

const int kb = std::min(Bp.blockRowSize(), Bp.numRows() - k_ind);
Expand All @@ -461,8 +478,7 @@ FBGEMM_API void cblas_gemm_compute(
int nbcol = n / Bp.blockColSize();
gp.k = kb;
gp.B = &(Bp(k_ind, 0));
gp.beta = &beta_;
gp.accum = accum;
gp.beta = beta_;
gp.C = &C[m2 * ldc];
gp.ldc = ldc * sizeof(C[0]);
gp.b_block_cols = nbcol;
Expand Down Expand Up @@ -529,7 +545,7 @@ FBGEMM_API void cblas_gemm_compute(
assert(
i * Bp.blockColSize() + (j - last_blk_col) <
sizeof(c_tmp) / sizeof(c_tmp[0]));
if (accum == 0) {
if (beta_ == 0.f) {
C[(m2 + i) * ldc + j] =
c_tmp[i * Bp.blockColSize() + (j - last_blk_col)];
} else {
Expand Down
35 changes: 35 additions & 0 deletions src/FbgemmFP16Common.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include <array>
#include <fbgemm/Types.h>
#include <fbgemm/Utils.h>

namespace fbgemm {
using partition_array_t = std::array<std::array<std::array<int, 2>, 2>, 121>;

template<typename T>
struct GemmParams {
uint64_t k;
float* A;
const T* B;
float beta;
float* C;
uint64_t ldc;
uint64_t b_block_cols;
uint64_t b_block_size;
};

template<typename T>
using funcptr_t = void(*)(GemmParams<T>*);

using fp16 = float16;
using fp32 = float;
using GemmParamsFP16 = GemmParams<fp16>;

}
Loading

0 comments on commit 6394aab

Please sign in to comment.