Back out "[fbgemm] fp16 gemm using avx512"
Summary: Original commit changeset: 6605bcecf391

Reviewed By: jspark1105

Differential Revision: D17692046

fbshipit-source-id: 6fd324f24ff0633f91f55b7194054ea6fbe27ed5
Ullas Simhan authored and facebook-github-bot committed Oct 1, 2019
1 parent e82b986 commit c8b8540
Showing 8 changed files with 476 additions and 3,247 deletions.
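For orientation before the per-file diffs: the code being backed out chose between AVX2 and AVX-512 micro-kernels at runtime (see the removed `has_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2` lines in src/FbgemmFP16.cc below). The following is a minimal, illustrative sketch of that dispatch pattern, not part of this commit; the stub kernel names and `pick_kernel` are invented stand-ins for FBGEMM's `gemmkernel_*` function pointers.

```cpp
#include <array>
#include <cstdio>

// Stand-ins for the per-ISA micro-kernels; the real code uses function
// pointers of type funcptr_fp16 that take a GemmParams*.
void kernel_avx2_stub() { std::puts("AVX2 kernel"); }
void kernel_avx512_stub() { std::puts("AVX-512 kernel"); }

using knl_ptr = void (*)();

// Dispatch pattern removed by this commit: pick the kernel table once,
// based on a runtime CPU-feature query.
knl_ptr pick_kernel(bool has_avx512) {
  static constexpr std::array<knl_ptr, 2> kernels = {kernel_avx2_stub,
                                                     kernel_avx512_stub};
  return kernels[has_avx512 ? 1 : 0];
}

int main() {
  // In FBGEMM the query is fbgemmHasAvx512Support() backed by cpuinfo;
  // here it is a hard-coded flag for illustration only.
  bool has_avx512 = false;
  pick_kernel(has_avx512)();
}
```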
4 changes: 1 addition & 3 deletions CMakeLists.txt
@@ -85,9 +85,7 @@ set(FBGEMM_AVX2_SRCS
src/UtilsAvx2.cc)

#All the source files that use avx512 instructions statically
-set(FBGEMM_AVX512_SRCS
-  src/FbgemmFP16UKernelsAvx512.cc
-  src/UtilsAvx512.cc)
+set(FBGEMM_AVX512_SRCS src/UtilsAvx512.cc)

set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
include/fbgemm/FbgemmBuild.h
6 changes: 1 addition & 5 deletions include/fbgemm/FbgemmFP16.h
@@ -9,7 +9,6 @@
// WARNING: this is a legacy fp16 fbgemm implementation and will soon be
// upgraded to match with new fbgemm interface.

-#include <cpuinfo.h>
#include <cassert>
#include <cstdlib>
#include <memory>
@@ -82,10 +81,7 @@ class PackedGemmMatrixFP16 {
}

void initializeParam() {
-if (!cpuinfo_initialize()) {
-throw std::runtime_error("Failed to initialize cpuinfo!");
-}
-bcol_ = (fbgemmHasAvx512Support() ? 16 : 8) * kernelNumColBlocks();
+bcol_ = 8 * kernelNumColBlocks();

// set up internal packing parameters
nbrow_ = ((numRows() % blockRowSize()) == 0)
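A small standalone sketch of what the removed ternary computed: the packed block-column width is the per-register float count (8 for an AVX2 ymm register, 16 for an AVX-512 zmm register) times the number of kernel column blocks, which the kernel comments in src/FbgemmFP16.cc put at 2. The function name here is illustrative, not the PackedGemmMatrixFP16 API.

```cpp
#include <iostream>

// Illustrative only: block-column width of the packed B matrix, mirroring
// the initializeParam() line changed above. An AVX2 register holds 8
// floats, an AVX-512 register holds 16; the fp16 kernels span
// kernel_ncol_blocks such register-wide column blocks.
int packedBlockColWidth(bool has_avx512, int kernel_ncol_blocks) {
  const int simd_width = has_avx512 ? 16 : 8;  // floats per vector register
  return simd_width * kernel_ncol_blocks;
}

int main() {
  const int kernel_ncol_blocks = 2;  // implied by the ?x2 kernel names
  std::cout << "AVX2 bcol_:    "
            << packedBlockColWidth(false, kernel_ncol_blocks) << "\n";  // 16
  std::cout << "AVX-512 bcol_: "
            << packedBlockColWidth(true, kernel_ncol_blocks) << "\n";   // 32
}
```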
208 changes: 24 additions & 184 deletions src/FbgemmFP16.cc
@@ -13,7 +13,6 @@
#include <utility>

#include "FbgemmFP16UKernelsAvx2.h"
#include "FbgemmFP16UKernelsAvx512.h"

using namespace std;

@@ -33,50 +32,28 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
transpose_simd(nrow, ncol, from, ldim, to, nrow);
}

// Each kernel does the following computation that multiplies
// mb x k A sub-matrix with k x b_block_cols*64 B sub-matrix
// for (int j = 0; j < b_block_cols * 64; j += 64) {
// for (int kk = 0; kk < k; ++k) {
// for (int i = 0; i < mb; ++i) {
// c[i][j:j+64] += a[i][kk] * b[kk][j:j+64]
// }
// }
// }

struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
// the restrictions of ymm register numbers (16).
-static constexpr knl_ptr kernel_avx2[] = {nullptr,
-gemmkernel_1x2_AVX2_fA0fB0fC0,
-gemmkernel_2x2_AVX2_fA0fB0fC0,
-gemmkernel_3x2_AVX2_fA0fB0fC0,
-gemmkernel_4x2_AVX2_fA0fB0fC0,
-gemmkernel_5x2_AVX2_fA0fB0fC0,
-gemmkernel_6x2_AVX2_fA0fB0fC0};

-static constexpr knl_ptr kernel_avx512[] = {nullptr,
-gemmkernel_1x2_AVX512_fA0fB0fC0,
-gemmkernel_2x2_AVX512_fA0fB0fC0,
-gemmkernel_3x2_AVX512_fA0fB0fC0,
-gemmkernel_4x2_AVX512_fA0fB0fC0,
-gemmkernel_5x2_AVX512_fA0fB0fC0,
-gemmkernel_6x2_AVX512_fA0fB0fC0,
-gemmkernel_7x2_AVX512_fA0fB0fC0,
-gemmkernel_8x2_AVX512_fA0fB0fC0,
-gemmkernel_9x2_AVX512_fA0fB0fC0,
-gemmkernel_10x2_AVX512_fA0fB0fC0,
-gemmkernel_11x2_AVX512_fA0fB0fC0,
-gemmkernel_12x2_AVX512_fA0fB0fC0,
-gemmkernel_13x2_AVX512_fA0fB0fC0,
-gemmkernel_14x2_AVX512_fA0fB0fC0};
+static constexpr array<knl_ptr, 7> kernel = {
+{
+nullptr,
+gemmkernel_1x2_AVX2_fA0fB0fC0,
+gemmkernel_2x2_AVX2_fA0fB0fC0,
+gemmkernel_3x2_AVX2_fA0fB0fC0,
+gemmkernel_4x2_AVX2_fA0fB0fC0,
+gemmkernel_5x2_AVX2_fA0fB0fC0,
+gemmkernel_6x2_AVX2_fA0fB0fC0
+}
+};

// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
// clang-format off
-static constexpr array<array<array<int, 2>, 2>, 121> partition_avx2 = {
+static constexpr array<array<array<int, 2>, 2>, 121> partition = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
@@ -203,139 +180,10 @@ struct KernelInfo {
{{ { 6, 20 }, { 0, 0 } } }, // 120
}
};
static constexpr array<array<array<int, 2>, 2>, 121> partition_avx512 = {
// NOTE: clang-format wants to use a different formatting but the current
// formatting should be easier to read.
{
{{ { 0, 0 }, { 0, 0 } } }, // 0
{{ { 1, 1 }, { 0, 0 } } }, // 1
{{ { 2, 1 }, { 0, 0 } } }, // 2
{{ { 3, 1 }, { 0, 0 } } }, // 3
{{ { 4, 1 }, { 0, 0 } } }, // 4
{{ { 5, 1 }, { 0, 0 } } }, // 5
{{ { 6, 1 }, { 0, 0 } } }, // 6
{{ { 7, 1 }, { 0, 0 } } }, // 7
{{ { 8, 1 }, { 0, 0 } } }, // 8
{{ { 9, 1 }, { 0, 0 } } }, // 9
{{ { 10, 1 }, { 0, 0 } } }, // 10
{{ { 11, 1 }, { 0, 0 } } }, // 11
{{ { 12, 1 }, { 0, 0 } } }, // 12
{{ { 13, 1 }, { 0, 0 } } }, // 13
{{ { 14, 1 }, { 0, 0 } } }, // 14
{{ { 8, 1 }, { 7, 1 } } }, // 15
{{ { 8, 2 }, { 0, 0 } } }, // 16
{{ { 9, 1 }, { 8, 1 } } }, // 17
{{ { 9, 2 }, { 0, 0 } } }, // 18
{{ { 10, 1 }, { 9, 1 } } }, // 19
{{ { 10, 2 }, { 0, 0 } } }, // 20
{{ { 11, 1 }, { 10, 1 } } }, // 21
{{ { 11, 2 }, { 0, 0 } } }, // 22
{{ { 12, 1 }, { 11, 1 } } }, // 23
{{ { 12, 2 }, { 0, 0 } } }, // 24
{{ { 13, 1 }, { 12, 1 } } }, // 25
{{ { 13, 2 }, { 0, 0 } } }, // 26
{{ { 14, 1 }, { 13, 1 } } }, // 27
{{ { 14, 2 }, { 0, 0 } } }, // 28
{{ { 10, 2 }, { 9, 1 } } }, // 29
{{ { 10, 3 }, { 0, 0 } } }, // 30
{{ { 11, 2 }, { 9, 1 } } }, // 31
{{ { 11, 2 }, { 10, 1 } } }, // 32
{{ { 11, 3 }, { 0, 0 } } }, // 33
{{ { 12, 2 }, { 10, 1 } } }, // 34
{{ { 12, 2 }, { 11, 1 } } }, // 35
{{ { 12, 3 }, { 0, 0 } } }, // 36
{{ { 13, 2 }, { 11, 1 } } }, // 37
{{ { 13, 2 }, { 12, 1 } } }, // 38
{{ { 13, 3 }, { 0, 0 } } }, // 39
{{ { 14, 2 }, { 12, 1 } } }, // 40
{{ { 14, 2 }, { 13, 1 } } }, // 41
{{ { 14, 3 }, { 0, 0 } } }, // 42
{{ { 11, 3 }, { 10, 1 } } }, // 43
{{ { 11, 4 }, { 0, 0 } } }, // 44
{{ { 12, 3 }, { 9, 1 } } }, // 45
{{ { 12, 3 }, { 10, 1 } } }, // 46
{{ { 12, 3 }, { 11, 1 } } }, // 47
{{ { 12, 4 }, { 0, 0 } } }, // 48
{{ { 13, 3 }, { 10, 1 } } }, // 49
{{ { 13, 3 }, { 11, 1 } } }, // 50
{{ { 13, 3 }, { 12, 1 } } }, // 51
{{ { 13, 4 }, { 0, 0 } } }, // 52
{{ { 14, 3 }, { 11, 1 } } }, // 53
{{ { 14, 3 }, { 12, 1 } } }, // 54
{{ { 14, 3 }, { 13, 1 } } }, // 55
{{ { 14, 4 }, { 0, 0 } } }, // 56
{{ { 12, 4 }, { 9, 1 } } }, // 57
{{ { 12, 4 }, { 10, 1 } } }, // 58
{{ { 12, 4 }, { 11, 1 } } }, // 59
{{ { 12, 5 }, { 0, 0 } } }, // 60
{{ { 13, 4 }, { 9, 1 } } }, // 61
{{ { 13, 4 }, { 10, 1 } } }, // 62
{{ { 13, 4 }, { 11, 1 } } }, // 63
{{ { 13, 4 }, { 12, 1 } } }, // 64
{{ { 13, 5 }, { 0, 0 } } }, // 65
{{ { 14, 4 }, { 10, 1 } } }, // 66
{{ { 14, 4 }, { 11, 1 } } }, // 67
{{ { 14, 4 }, { 12, 1 } } }, // 68
{{ { 14, 4 }, { 13, 1 } } }, // 69
{{ { 14, 5 }, { 0, 0 } } }, // 70
{{ { 12, 5 }, { 11, 1 } } }, // 71
{{ { 12, 6 }, { 0, 0 } } }, // 72
{{ { 13, 5 }, { 8, 1 } } }, // 73
{{ { 13, 5 }, { 9, 1 } } }, // 74
{{ { 13, 5 }, { 10, 1 } } }, // 75
{{ { 13, 5 }, { 11, 1 } } }, // 76
{{ { 13, 5 }, { 12, 1 } } }, // 77
{{ { 13, 6 }, { 0, 0 } } }, // 78
{{ { 14, 5 }, { 9, 1 } } }, // 79
{{ { 14, 5 }, { 10, 1 } } }, // 80
{{ { 14, 5 }, { 11, 1 } } }, // 81
{{ { 14, 5 }, { 12, 1 } } }, // 82
{{ { 14, 5 }, { 13, 1 } } }, // 83
{{ { 14, 6 }, { 0, 0 } } }, // 84
{{ { 13, 6 }, { 7, 1 } } }, // 85
{{ { 13, 6 }, { 8, 1 } } }, // 86
{{ { 13, 6 }, { 9, 1 } } }, // 87
{{ { 13, 6 }, { 10, 1 } } }, // 88
{{ { 13, 6 }, { 11, 1 } } }, // 89
{{ { 13, 6 }, { 12, 1 } } }, // 90
{{ { 13, 7 }, { 0, 0 } } }, // 91
{{ { 14, 6 }, { 8, 1 } } }, // 92
{{ { 14, 6 }, { 9, 1 } } }, // 93
{{ { 14, 6 }, { 10, 1 } } }, // 94
{{ { 14, 6 }, { 11, 1 } } }, // 95
{{ { 14, 6 }, { 12, 1 } } }, // 96
{{ { 14, 6 }, { 13, 1 } } }, // 97
{{ { 14, 7 }, { 0, 0 } } }, // 98
{{ { 13, 7 }, { 8, 1 } } }, // 99
{{ { 13, 7 }, { 9, 1 } } }, // 100
{{ { 13, 7 }, { 10, 1 } } }, // 101
{{ { 13, 7 }, { 11, 1 } } }, // 102
{{ { 13, 7 }, { 12, 1 } } }, // 103
{{ { 13, 8 }, { 0, 0 } } }, // 104
{{ { 14, 7 }, { 7, 1 } } }, // 105
{{ { 14, 7 }, { 8, 1 } } }, // 106
{{ { 14, 7 }, { 9, 1 } } }, // 107
{{ { 14, 7 }, { 10, 1 } } }, // 108
{{ { 14, 7 }, { 11, 1 } } }, // 109
{{ { 14, 7 }, { 12, 1 } } }, // 110
{{ { 14, 7 }, { 13, 1 } } }, // 111
{{ { 14, 8 }, { 0, 0 } } }, // 112
{{ { 13, 8 }, { 9, 1 } } }, // 113
{{ { 13, 8 }, { 10, 1 } } }, // 114
{{ { 13, 8 }, { 11, 1 } } }, // 115
{{ { 13, 8 }, { 12, 1 } } }, // 116
{{ { 13, 9 }, { 0, 0 } } }, // 117
{{ { 14, 8 }, { 6, 1 } } }, // 118
{{ { 14, 8 }, { 7, 1 } } }, // 119
{{ { 14, 8 }, { 8, 1 } } }, // 120
}
};
// clang-format on
};
-constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx2[];
-constexpr KernelInfo::knl_ptr KernelInfo::kernel_avx512[];
-constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx2;
-constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition_avx512;
+constexpr array<KernelInfo::knl_ptr, 7> KernelInfo::kernel;
+constexpr array<array<array<int, 2>, 2>, 121> KernelInfo::partition;

// autotuned kernel splits for various cases m = 1:mb_max
FBGEMM_API void cblas_gemm_compute(
@@ -356,8 +204,7 @@
// constants
const int n = Bp.numCols(), k = Bp.numRows(), ldc = n;
const int mb_max = 120;
-bool has_avx512 = fbgemmHasAvx512Support();
-int simd_width = has_avx512 ? 16 : 8;
+constexpr int simd_width = 8;
int kernel_ncol_blocks = Bp.kernelNumColBlocks();
int kernel_ncols = kernel_ncol_blocks * simd_width;

@@ -367,18 +214,13 @@

GemmParams gp;

-const funcptr_fp16* kernels =
-has_avx512 ? KernelInfo::kernel_avx512 : KernelInfo::kernel_avx2;
-const array<array<array<int, 2>, 2>, 121>& partition =
-has_avx512 ? KernelInfo::partition_avx512 : KernelInfo::partition_avx2;

int i_begin, i_end;
// fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end);
i_begin = 0;
i_end = m;
for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
int mb = std::min(mb_max, i_end - m0);
-assert(mb < partition.size());
+assert(mb < KernelInfo::partition.size());
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid "Nan" problem
float beta_;
@@ -398,8 +240,8 @@

auto m1 = m0;
for (auto c = 0; c < 2; c++) {
-auto kernel_nrows = partition[mb][c][0];
-auto nkernel_nrows = partition[mb][c][1];
+auto kernel_nrows = KernelInfo::partition[mb][c][0];
+auto nkernel_nrows = KernelInfo::partition[mb][c][1];

auto m_start = m1, m_end = m1 + kernel_nrows * nkernel_nrows;
for (auto m2 = m_start; m2 < m_end; m2 += kernel_nrows) {
@@ -433,7 +275,7 @@ FBGEMM_API void cblas_gemm_compute(
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
-kernels[kernel_nrows](&gp);
+KernelInfo::kernel[kernel_nrows](&gp);
}
} else {
int last_blk_col = nbcol * Bp.blockColSize();
@@ -445,7 +287,7 @@
gp.C += Bp.blockColSize() * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
-kernels[kernel_nrows](&gp);
+KernelInfo::kernel[kernel_nrows](&gp);
}
}

@@ -459,21 +301,19 @@
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
-kernels[kernel_nrows](&gp);
+KernelInfo::kernel[kernel_nrows](&gp);
} else {
// small temporary buffer: the size should be larger than the
// required kernel_nrow x kernel_ncols elements computed in the
// registers.
-float c_tmp[14 * 32] = {0};
-assert(
-sizeof(c_tmp) / sizeof(c_tmp[0]) >=
-kernel_nrows * kernel_ncols);
+float c_tmp[16 * 24] = {0};
+assert((16 * 24) > kernel_nrows * kernel_ncols);

gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
gp.ldc = kernel_ncols * sizeof(C[0]);
gp.b_block_cols = 1;
-kernels[kernel_nrows](&gp);
+KernelInfo::kernel[kernel_nrows](&gp);
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
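As a reading aid for the partition table and the row loop in src/FbgemmFP16.cc above: each `KernelInfo::partition[mb]` entry holds up to two `{kernel_nrows, nkernel_nrows}` pairs that tile an mb-row slab with fixed-height kernels, and `kernel_nrows` also indexes `KernelInfo::kernel`. Below is a self-contained, illustrative sketch of that decomposition, not FBGEMM code; it uses the mb = 120 entry `{{ 6, 20 }, { 0, 0 }}` visible in the AVX2 table above and replaces the kernel call with a printout.

```cpp
#include <array>
#include <cstdio>

// Illustrative only: walk a partition entry the way cblas_gemm_compute
// does, calling a kernel_nrows-high kernel nkernel_nrows times for each
// of the (at most two) pairs.
void decompose(int mb, const std::array<std::array<int, 2>, 2>& entry) {
  int row = 0;
  for (const auto& pair : entry) {
    const int kernel_nrows = pair[0];   // rows handled per kernel call
    const int nkernel_nrows = pair[1];  // number of such calls
    for (int i = 0; i < nkernel_nrows; ++i) {
      std::printf("rows [%d, %d): %dx2 kernel\n", row, row + kernel_nrows,
                  kernel_nrows);
      row += kernel_nrows;
    }
  }
  // After both pairs, row should equal mb.
  std::printf("covered %d of %d rows\n", row, mb);
}

int main() {
  // partition[120] from the diff above: 120 rows = 20 calls of the 6x2 kernel.
  const std::array<std::array<int, 2>, 2> entry120 = {{{6, 20}, {0, 0}}};
  decompose(120, entry120);
}
```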

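Finally, the `c_tmp` change in the last hunk touches the tail path: when fewer than `kernel_ncols` columns remain, the kernel writes a full-width tile into a small stack buffer with leading dimension `kernel_ncols`, and only the valid columns are copied into C. The sketch below is a simplified, self-contained illustration of that pattern with the kernel call elided; the function name and sizes are assumptions, not FBGEMM's API, and the stride here is in elements whereas the real code sets `gp.ldc` in bytes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstring>

// Illustrative only: copy-back of the remainder columns, mirroring the
// c_tmp path in cblas_gemm_compute above. The kernel always produces a
// kernel_nrows x kernel_ncols tile, so it writes into c_tmp with leading
// dimension kernel_ncols, and only n - last_blk_col columns are real.
void copy_tail_columns(float* C, int ldc, int kernel_nrows, int kernel_ncols,
                       int last_blk_col, int n) {
  float c_tmp[16 * 24] = {0};  // must hold kernel_nrows * kernel_ncols floats
  assert(kernel_nrows * kernel_ncols <= 16 * 24);

  // ... a real micro-kernel call would fill c_tmp here ...

  const int valid_cols = std::min(n - last_blk_col, kernel_ncols);
  for (int i = 0; i < kernel_nrows; ++i) {
    std::memcpy(&C[i * ldc + last_blk_col], &c_tmp[i * kernel_ncols],
                valid_cols * sizeof(float));
  }
}

int main() {
  float C[2 * 18] = {0};
  // e.g. n = 18 output columns: the last full 16-wide block ends at 16,
  // so a 2-row kernel tile covers the 2 remaining columns plus padding.
  copy_tail_columns(C, /*ldc=*/18, /*kernel_nrows=*/2, /*kernel_ncols=*/16,
                    /*last_blk_col=*/16, /*n=*/18);
}
```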