Skip to content

Commit

Permalink
fp16 gemm using avx512 (pytorch#137)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#137

fp16 GEMM was not using avx512, so it fell behind fp32 performance for large m cases.
This diff enables avx512 for fp16 GEMM. Further tuning of the register blocking size may be needed.
Longer term, we would also need to use JIT'ing for fp16.

Reviewed By: jianyuh

Differential Revision: D17786712

fbshipit-source-id: bebf8723d03db7e128097310745a8103b712ee06
  • Loading branch information
jspark1105 authored and facebook-github-bot committed Oct 8, 2019
1 parent 8786c08 commit 82d259d
Show file tree
Hide file tree
Showing 9 changed files with 3,306 additions and 479 deletions.
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ set(FBGEMM_AVX2_SRCS
src/UtilsAvx2.cc)

#All the source files that use avx512 instructions statically
set(FBGEMM_AVX512_SRCS src/UtilsAvx512.cc)
set(FBGEMM_AVX512_SRCS
src/FbgemmFP16UKernelsAvx512.cc
src/UtilsAvx512.cc)

set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
include/fbgemm/FbgemmBuild.h
Expand Down
6 changes: 5 additions & 1 deletion include/fbgemm/FbgemmFP16.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
// WARNING: this is a legacy fp16 fbgemm implementation and will soon be
// upgraded to match with new fbgemm interface.

#include <cpuinfo.h>
#include <cassert>
#include <cstdlib>
#include <memory>
Expand Down Expand Up @@ -81,7 +82,10 @@ class PackedGemmMatrixFP16 {
}

void initializeParam() {
bcol_ = 8 * kernelNumColBlocks();
if (!cpuinfo_initialize()) {
throw std::runtime_error("Failed to initialize cpuinfo!");
}
bcol_ = (fbgemmHasAvx512Support() ? 16 : 8) * kernelNumColBlocks();

// set up internal packing parameters
nbrow_ = ((numRows() % blockRowSize()) == 0)
Expand Down
268 changes: 244 additions & 24 deletions src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@

#include <cpuinfo.h>
#include <array>
#include <cmath>
#include <utility>

#include "FbgemmFP16UKernelsAvx2.h"
#include "FbgemmFP16UKernelsAvx512.h"

using namespace std;

Expand All @@ -32,28 +34,50 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
transpose_simd(nrow, ncol, from, ldim, to, nrow);
}

// Each kernel does the following computation that multiplies
// mb x k A sub-matrix with k x b_block_cols*64 B sub-matrix
// for (int j = 0; j < b_block_cols * 64; j += 64) {
// for (int kk = 0; kk < k; ++kk) {
// for (int i = 0; i < mb; ++i) {
// c[i][j:j+64] += a[i][kk] * b[kk][j:j+64]
// }
// }
// }

struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels for avx2,
// due to the restrictions of ymm register numbers (16). The avx512 kernel
// table below goes up to 14x2.
static constexpr array<knl_ptr, 7> kernel = {
{
nullptr,
gemmkernel_1x2_AVX2_fA0fB0fC0,
gemmkernel_2x2_AVX2_fA0fB0fC0,
gemmkernel_3x2_AVX2_fA0fB0fC0,
gemmkernel_4x2_AVX2_fA0fB0fC0,
gemmkernel_5x2_AVX2_fA0fB0fC0,
gemmkernel_6x2_AVX2_fA0fB0fC0
}
};
static constexpr knl_ptr kernel_avx2[] = {nullptr,
gemmkernel_1x2_AVX2_fA0fB0fC0,
gemmkernel_2x2_AVX2_fA0fB0fC0,
gemmkernel_3x2_AVX2_fA0fB0fC0,
gemmkernel_4x2_AVX2_fA0fB0fC0,
gemmkernel_5x2_AVX2_fA0fB0fC0,
gemmkernel_6x2_AVX2_fA0fB0fC0};

static constexpr knl_ptr kernel_avx512[] = {nullptr,
gemmkernel_1x2_AVX512_fA0fB0fC0,
gemmkernel_2x2_AVX512_fA0fB0fC0,
gemmkernel_3x2_AVX512_fA0fB0fC0,
gemmkernel_4x2_AVX512_fA0fB0fC0,
gemmkernel_5x2_AVX512_fA0fB0fC0,
gemmkernel_6x2_AVX512_fA0fB0fC0,
gemmkernel_7x2_AVX512_fA0fB0fC0,
gemmkernel_8x2_AVX512_fA0fB0fC0,
gemmkernel_9x2_AVX512_fA0fB0fC0,
gemmkernel_10x2_AVX512_fA0fB0fC0,
gemmkernel_11x2_AVX512_fA0fB0fC0,
gemmkernel_12x2_AVX512_fA0fB0fC0,
gemmkernel_13x2_AVX512_fA0fB0fC0,
gemmkernel_14x2_AVX512_fA0fB0fC0};

// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
// clang-format off
static constexpr array<array<array<int, 2>, 2>, 121> partition = {
static constexpr array<array<array<int, 2>, 2>, 121> partition_avx2 = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
Expand Down Expand Up @@ -180,10 +204,178 @@ struct KernelInfo {
{{ { 6, 20 }, { 0, 0 } } }, // 120
}
};
static constexpr array<array<array<int, 2>, 2>, 121> partition_avx512 = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
{{ { 0, 0 }, { 0, 0 } } }, // 0
{{ { 1, 1 }, { 0, 0 } } }, // 1
{{ { 2, 1 }, { 0, 0 } } }, // 2
{{ { 3, 1 }, { 0, 0 } } }, // 3
{{ { 4, 1 }, { 0, 0 } } }, // 4
{{ { 5, 1 }, { 0, 0 } } }, // 5
{{ { 6, 1 }, { 0, 0 } } }, // 6
{{ { 7, 1 }, { 0, 0 } } }, // 7
{{ { 8, 1 }, { 0, 0 } } }, // 8
{{ { 9, 1 }, { 0, 0 } } }, // 9
{{ { 10, 1 }, { 0, 0 } } }, // 10
{{ { 11, 1 }, { 0, 0 } } }, // 11
{{ { 12, 1 }, { 0, 0 } } }, // 12
{{ { 13, 1 }, { 0, 0 } } }, // 13
{{ { 14, 1 }, { 0, 0 } } }, // 14
{{ { 8, 1 }, { 7, 1 } } }, // 15
{{ { 8, 2 }, { 0, 0 } } }, // 16
{{ { 9, 1 }, { 8, 1 } } }, // 17
{{ { 9, 2 }, { 0, 0 } } }, // 18
{{ { 10, 1 }, { 9, 1 } } }, // 19
{{ { 10, 2 }, { 0, 0 } } }, // 20
{{ { 11, 1 }, { 10, 1 } } }, // 21
{{ { 11, 2 }, { 0, 0 } } }, // 22
{{ { 12, 1 }, { 11, 1 } } }, // 23
{{ { 12, 2 }, { 0, 0 } } }, // 24
{{ { 13, 1 }, { 12, 1 } } }, // 25
{{ { 13, 2 }, { 0, 0 } } }, // 26
{{ { 14, 1 }, { 13, 1 } } }, // 27
{{ { 14, 2 }, { 0, 0 } } }, // 28
{{ { 10, 2 }, { 9, 1 } } }, // 29
{{ { 10, 3 }, { 0, 0 } } }, // 30
{{ { 11, 2 }, { 9, 1 } } }, // 31
{{ { 11, 2 }, { 10, 1 } } }, // 32
{{ { 11, 3 }, { 0, 0 } } }, // 33
{{ { 12, 2 }, { 10, 1 } } }, // 34
{{ { 12, 2 }, { 11, 1 } } }, // 35
{{ { 12, 3 }, { 0, 0 } } }, // 36
{{ { 13, 2 }, { 11, 1 } } }, // 37
{{ { 13, 2 }, { 12, 1 } } }, // 38
{{ { 13, 3 }, { 0, 0 } } }, // 39
{{ { 14, 2 }, { 12, 1 } } }, // 40
{{ { 14, 2 }, { 13, 1 } } }, // 41
{{ { 14, 3 }, { 0, 0 } } }, // 42
{{ { 11, 3 }, { 10, 1 } } }, // 43
{{ { 11, 4 }, { 0, 0 } } }, // 44
{{ { 12, 3 }, { 9, 1 } } }, // 45
{{ { 12, 3 }, { 10, 1 } } }, // 46
{{ { 12, 3 }, { 11, 1 } } }, // 47
{{ { 12, 4 }, { 0, 0 } } }, // 48
{{ { 13, 3 }, { 10, 1 } } }, // 49
{{ { 13, 3 }, { 11, 1 } } }, // 50
{{ { 13, 3 }, { 12, 1 } } }, // 51
{{ { 13, 4 }, { 0, 0 } } }, // 52
{{ { 14, 3 }, { 11, 1 } } }, // 53
{{ { 14, 3 }, { 12, 1 } } }, // 54
{{ { 14, 3 }, { 13, 1 } } }, // 55
{{ { 14, 4 }, { 0, 0 } } }, // 56
{{ { 12, 4 }, { 9, 1 } } }, // 57
{{ { 12, 4 }, { 10, 1 } } }, // 58
{{ { 12, 4 }, { 11, 1 } } }, // 59
{{ { 12, 5 }, { 0, 0 } } }, // 60
{{ { 13, 4 }, { 9, 1 } } }, // 61
{{ { 13, 4 }, { 10, 1 } } }, // 62
{{ { 13, 4 }, { 11, 1 } } }, // 63
{{ { 13, 4 }, { 12, 1 } } }, // 64
{{ { 13, 5 }, { 0, 0 } } }, // 65
{{ { 14, 4 }, { 10, 1 } } }, // 66
{{ { 14, 4 }, { 11, 1 } } }, // 67
{{ { 14, 4 }, { 12, 1 } } }, // 68
{{ { 14, 4 }, { 13, 1 } } }, // 69
{{ { 14, 5 }, { 0, 0 } } }, // 70
{{ { 12, 5 }, { 11, 1 } } }, // 71
{{ { 12, 6 }, { 0, 0 } } }, // 72
{{ { 13, 5 }, { 8, 1 } } }, // 73
{{ { 13, 5 }, { 9, 1 } } }, // 74
{{ { 13, 5 }, { 10, 1 } } }, // 75
{{ { 13, 5 }, { 11, 1 } } }, // 76
{{ { 13, 5 }, { 12, 1 } } }, // 77
{{ { 13, 6 }, { 0, 0 } } }, // 78
{{ { 14, 5 }, { 9, 1 } } }, // 79
{{ { 14, 5 }, { 10, 1 } } }, // 80
{{ { 14, 5 }, { 11, 1 } } }, // 81
{{ { 14, 5 }, { 12, 1 } } }, // 82
{{ { 14, 5 }, { 13, 1 } } }, // 83
{{ { 14, 6 }, { 0, 0 } } }, // 84
{{ { 13, 6 }, { 7, 1 } } }, // 85
{{ { 13, 6 }, { 8, 1 } } }, // 86
{{ { 13, 6 }, { 9, 1 } } }, // 87
{{ { 13, 6 }, { 10, 1 } } }, // 88
{{ { 13, 6 }, { 11, 1 } } }, // 89
{{ { 13, 6 }, { 12, 1 } } }, // 90
{{ { 13, 7 }, { 0, 0 } } }, // 91
{{ { 14, 6 }, { 8, 1 } } }, // 92
{{ { 14, 6 }, { 9, 1 } } }, // 93
{{ { 14, 6 }, { 10, 1 } } }, // 94
{{ { 14, 6 }, { 11, 1 } } }, // 95
{{ { 14, 6 }, { 12, 1 } } }, // 96
{{ { 14, 6 }, { 13, 1 } } }, // 97
{{ { 14, 7 }, { 0, 0 } } }, // 98
{{ { 13, 7 }, { 8, 1 } } }, // 99
{{ { 13, 7 }, { 9, 1 } } }, // 100
{{ { 13, 7 }, { 10, 1 } } }, // 101
{{ { 13, 7 }, { 11, 1 } } }, // 102
{{ { 13, 7 }, { 12, 1 } } }, // 103
{{ { 13, 8 }, { 0, 0 } } }, // 104
{{ { 14, 7 }, { 7, 1 } } }, // 105
{{ { 14, 7 }, { 8, 1 } } }, // 106
{{ { 14, 7 }, { 9, 1 } } }, // 107
{{ { 14, 7 }, { 10, 1 } } }, // 108
{{ { 14, 7 }, { 11, 1 } } }, // 109
{{ { 14, 7 }, { 12, 1 } } }, // 110
{{ { 14, 7 }, { 13, 1 } } }, // 111
{{ { 14, 8 }, { 0, 0 } } }, // 112
{{ { 13, 8 }, { 9, 1 } } }, // 113
{{ { 13, 8 }, { 10, 1 } } }, // 114
{{ { 13, 8 }, { 11, 1 } } }, // 115
{{ { 13, 8 }, { 12, 1 } } }, // 116
{{ { 13, 9 }, { 0, 0 } } }, // 117
{{ { 14, 8 }, { 6, 1 } } }, // 118
{{ { 14, 8 }, { 7, 1 } } }, // 119
{{ { 14, 8 }, { 8, 1 } } }, // 120
}
};
// clang-format on
};
constexpr array<KernelInfo::knl_ptr, 7> KernelInfo::kernel;
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition;
constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx2[];
constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx512[];
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx2;
constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx512;

// define this to debug fp16 kernel using a reference C implementation
// #define FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
namespace {
// Scalar reference implementation of a single fp16 micro-kernel call. It is
// used instead of the hand-written kernels when
// FBGEMM_FP16_FALLBACK_TO_REF_KERNEL is defined, to help debug them.
//
// For every column block jb in [0, gp->b_block_cols) it accumulates, over
// k in [0, gp->k), the product of A (fp32, read as A[i + k * kernel_nrows])
// and B (fp16, read as B[(jb * gp->k + k) * block_col_size + j] and converted
// with cpu_half2float) into C.
//
// Parameters:
//   kernel_nrows - number of rows of A / C processed by this call.
//   gp           - kernel arguments (A, B, C, k, ldc, beta, accum,
//                  b_block_cols). gp->ldc is divided by sizeof(float) before
//                  being used as a row stride, i.e. it is expressed in bytes.
//   C_base, m_total, n_total
//                - only used by the bounds assertion: every written element
//                  must lie before C_base + m_total * n_total.
//   use_avx512   - selects the vector length (16 vs. 8 floats) so that the
//                  column-block width matches the layout B was packed with.
void ref_kernel(
    int kernel_nrows,
    GemmParams* gp,
    const float* C_base,
    int m_total,
    int n_total,
    bool use_avx512) {
  const int vlen = use_avx512 ? 16 : 8;
  const int kernel_ncol_blocks = 2;
  const int block_col_size = vlen * kernel_ncol_blocks;
  for (int jb = 0; jb < gp->b_block_cols; ++jb) {
    for (int k = 0; k < gp->k; ++k) {
      for (int i = 0; i < kernel_nrows; ++i) {
        const float a = gp->A[i + k * kernel_nrows];
        for (int j = 0; j < block_col_size; ++j) {
          float* C_ptr =
              gp->C + i * (gp->ldc / sizeof(float)) + jb * block_col_size + j;
          assert(C_ptr < C_base + m_total * n_total);
          const float b =
              cpu_half2float(gp->B[(jb * gp->k + k) * block_col_size + j]);
          if (k > 0) {
            // C already holds the running sum for this call.
            *C_ptr = std::fma(a, b, *C_ptr);
          } else if (gp->accum) {
            // Scale the incoming C by beta exactly once, on the first k step.
            // Applying beta on every k step would repeatedly rescale the
            // partial sum and give wrong results whenever beta != 1.
            *C_ptr = std::fma(a, b, (*gp->beta) * (*C_ptr));
          } else {
            // No accumulation requested: overwrite whatever C contained.
            *C_ptr = a * b;
          }
        }
      }
    }
  }
}
} // anonymous namespace
#endif // FBGEMM_FP16_FALLBACK_TO_REF_KERNEL

// autotuned kernel splits for various cases m = 1:mb_max
FBGEMM_API void cblas_gemm_compute(
Expand All @@ -204,8 +396,12 @@ FBGEMM_API void cblas_gemm_compute(
// constants
const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
const int mb_max = 120;
constexpr int simd_width = 8;
int kernel_ncol_blocks = Bp.kernelNumColBlocks();
// If, for whatever reason, packed B uses the packing layout for avx2, we
// just use avx2 even if avx512 is available.
bool use_avx512 = fbgemmHasAvx512Support() &&
(Bp.blockColSize() == 16 * kernel_ncol_blocks);
int simd_width = use_avx512 ? 16 : 8;
int kernel_ncols = kernel_ncol_blocks * simd_width;

// private scratchpad storage
Expand All @@ -214,13 +410,18 @@ FBGEMM_API void cblas_gemm_compute(

GemmParams gp;

const funcptr_fp16* kernels =
use_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2;
const array<array<array<int, 2>, 2>, 121>& partition =
use_avx512 ? KernelInfo::partition_avx512 : KernelInfo::partition_avx2;

int i_begin, i_end;
// fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end);
i_begin = 0;
i_end = m;
for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
int mb = std::min(mb_max, i_end - m0);
assert(mb < KernelInfo::partition.size());
assert(mb < partition.size());
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid "Nan" problem
float beta_;
Expand All @@ -240,8 +441,8 @@ FBGEMM_API void cblas_gemm_compute(

auto m1 = m0;
for (auto c = 0; c < 2; c++) {
auto kernel_nrows = KernelInfo::partition[mb][c][0];
auto nkernel_nrows = KernelInfo::partition[mb][c][1];
auto kernel_nrows = partition[mb][c][0];
auto nkernel_nrows = partition[mb][c][1];

auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
Expand Down Expand Up @@ -275,7 +476,11 @@ FBGEMM_API void cblas_gemm_compute(
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
ref_kernel(kernel_nrows, &gp, C, m, n, use_avx512);
#else
kernels[kernel_nrows](&gp);
#endif
}
} else {
int last_blk_col = nbcol * Bp.blockColSize();
Expand All @@ -287,7 +492,11 @@ FBGEMM_API void cblas_gemm_compute(
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
ref_kernel(kernel_nrows, &gp, C, m, n, use_avx512);
#else
kernels[kernel_nrows](&gp);
#endif
}
}

Expand All @@ -298,22 +507,33 @@ FBGEMM_API void cblas_gemm_compute(
assert(rem < kernel_ncols);

if ((rem % Bp.blockColSize()) == 0) {
// FIXME : can't happen
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
ref_kernel(kernel_nrows, &gp, C, m, n, use_avx512);
#else
kernels[kernel_nrows](&gp);
#endif
} else {
// small temporary buffer: the size should be at least the
// required kernel_nrows x kernel_ncols elements computed in the
// registers.
float c_tmp[16 * 24] = {0};
assert((16 * 24) > kernel_nrows * kernel_ncols);
float c_tmp[14 * 32] = {0};
assert(
sizeof(c_tmp) / sizeof(c_tmp[0]) >=
kernel_nrows * kernel_ncols);

gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
gp.ldc = kernel_ncols * sizeof(C[0]);
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
#ifdef FBGEMM_FP16_FALLBACK_TO_REF_KERNEL
ref_kernel(kernel_nrows, &gp, c_tmp, 14, 32, use_avx512);
#else
kernels[kernel_nrows](&gp);
#endif
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
Expand Down
Loading

0 comments on commit 82d259d

Please sign in to comment.