Add threading for FBGEMM FP16
Summary: Add threading support for FBGEMM FP16 routines.

Reviewed By: dskhudia, jacobkahn

Differential Revision: D13792341

fbshipit-source-id: eb31a11340ac9fd0ee9b4f570d161e7c7e6a7602
jianyuh authored and facebook-github-bot committed Jan 30, 2019
1 parent 7933330 commit 03a8fa5
Showing 4 changed files with 140 additions and 74 deletions.
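
The change threads the FP16 GEMM: cblas_gemm_compute gains explicit thread_id and num_threads parameters, and each thread computes a disjoint slice of B's column blocks. A minimal sketch of the new calling convention, mirroring the benchmark and test changes below (it assumes an OpenMP build and that A and C are std::vector<float> buffers of an m x k and m x n matrix, with Bp the prepacked FP16 right-hand side; without OpenMP, fbgemm_get_thread_num()/fbgemm_get_num_threads() are expected to report a single thread):

#include "fbgemm/FbgemmFP16.h"
using namespace fbgemm;

// Every thread makes the same call with its own id; the routine
// internally assigns each thread a range of B's column blocks.
#ifdef _OPENMP
#pragma omp parallel
#endif
{
  int num_threads = fbgemm_get_num_threads();
  int tid = fbgemm_get_thread_num();
  cblas_gemm_compute(
      matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C.data(),
      tid, num_threads);
}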
39 changes: 35 additions & 4 deletions bench/FP16Benchmark.cc
@@ -122,8 +122,22 @@ void performance_test() {
C_ref.data(),
n);
#endif
cblas_gemm_compute(
matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data());
#ifdef _OPENMP
#pragma omp parallel
#endif
{
int num_threads = fbgemm_get_num_threads();
int tid = fbgemm_get_thread_num();
cblas_gemm_compute(
matrix_op_t::NoTranspose,
m,
A.data(),
Bp,
beta,
C_fb.data(),
tid,
num_threads);
}

#if defined(USE_MKL) || defined(USE_BLAS)
// Compare results
@@ -201,8 +215,25 @@ void performance_test() {
}

t_begin = chrono::system_clock::now();
cblas_gemm_compute(
matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C_fb.data());

#ifdef _OPENMP
#pragma omp parallel
#endif
{
int num_threads = fbgemm_get_num_threads();
int tid = fbgemm_get_thread_num();

cblas_gemm_compute(
matrix_op_t::NoTranspose,
m,
A.data(),
Bp,
beta,
C_fb.data(),
tid,
num_threads);
}

t_end = chrono::system_clock::now();

if (it >= 0) {
23 changes: 6 additions & 17 deletions include/fbgemm/FbgemmFP16.h
@@ -192,14 +192,9 @@ class PackedGemmMatrixFP16 {
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
float* C);
friend void cblas_gemm_compute(
const matrix_op_t transa,
const int m,
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
float* C);
float* C,
int thread_id,
int num_threads);
};

/**
@@ -211,13 +206,7 @@ extern void cblas_gemm_compute(
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
float* C);
extern void cblas_gemm_compute(
const matrix_op_t transa,
const int m,
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
float* C);

float* C,
int thread_id = 0,
int num_threads = 1);
}; // namespace fbgemm
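
Because thread_id and num_threads default to 0 and 1 in the declaration above, existing single-threaded call sites keep compiling and behaving as before, e.g.:

// Unchanged caller: runs the whole GEMM on the calling thread.
cblas_gemm_compute(matrix_op_t::NoTranspose, m, A.data(), Bp, beta, C.data());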
136 changes: 84 additions & 52 deletions src/FbgemmFP16.cc
@@ -6,6 +6,8 @@
*/
#include "fbgemm/FbgemmFP16.h"

#include "fbgemm/Fbgemm.h"

#include <cpuinfo.h>
#include <array>
#include <utility>
@@ -34,21 +36,24 @@ struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
static constexpr array<knl_ptr, 15> kernel = {
{nullptr,
gemmkernel_1x1_AVX2_fA0fB0fC0,
gemmkernel_2x1_AVX2_fA0fB0fC0,
gemmkernel_3x1_AVX2_fA0fB0fC0,
gemmkernel_4x1_AVX2_fA0fB0fC0,
gemmkernel_5x1_AVX2_fA0fB0fC0,
gemmkernel_6x1_AVX2_fA0fB0fC0,
gemmkernel_7x1_AVX2_fA0fB0fC0,
gemmkernel_8x1_AVX2_fA0fB0fC0,
gemmkernel_9x1_AVX2_fA0fB0fC0,
gemmkernel_10x1_AVX2_fA0fB0fC0,
gemmkernel_11x1_AVX2_fA0fB0fC0,
gemmkernel_12x1_AVX2_fA0fB0fC0,
gemmkernel_13x1_AVX2_fA0fB0fC0,
gemmkernel_14x1_AVX2_fA0fB0fC0}};
{
nullptr,
gemmkernel_1x1_AVX2_fA0fB0fC0,
gemmkernel_2x1_AVX2_fA0fB0fC0,
gemmkernel_3x1_AVX2_fA0fB0fC0,
gemmkernel_4x1_AVX2_fA0fB0fC0,
gemmkernel_5x1_AVX2_fA0fB0fC0,
gemmkernel_6x1_AVX2_fA0fB0fC0,
gemmkernel_7x1_AVX2_fA0fB0fC0,
gemmkernel_8x1_AVX2_fA0fB0fC0,
gemmkernel_9x1_AVX2_fA0fB0fC0,
gemmkernel_10x1_AVX2_fA0fB0fC0,
gemmkernel_11x1_AVX2_fA0fB0fC0,
gemmkernel_12x1_AVX2_fA0fB0fC0,
gemmkernel_13x1_AVX2_fA0fB0fC0,
gemmkernel_14x1_AVX2_fA0fB0fC0
}
};
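// Note: kernel[r] (r = 1..14) is the microkernel that computes an r-row
// strip of A against one column block of B; index 0 is an unused placeholder.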

// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
@@ -177,7 +182,7 @@ struct KernelInfo {
{{ { 12, 9 }, { 10, 1 } } },
{{ { 12, 9 }, { 11, 1 } } },
{{ { 12, 10 }, { 0, 0 } } }
}
}
};
};
constexpr array<KernelInfo::knl_ptr, 15> KernelInfo::kernel;
@@ -190,7 +195,9 @@ FBGEMM_API void cblas_gemm_compute(
const float* A,
const PackedGemmMatrixFP16& Bp,
const float beta,
float* C) {
float* C,
int thread_id,
int num_threads) {
// ground truth
assert(cpuinfo_initialize());
assert(cpuinfo_has_x86_fma3());
@@ -209,8 +216,13 @@ FBGEMM_API void cblas_gemm_compute(
new std::array<float, 256 * 1024>());

GemmParams gp;
for (auto m0 = 0; m0 < m; m0 += mb_max) {
int mb = std::min(mb_max, m - m0);

int i_begin, i_end;
// fbgemmGetRange(num_threads, thread_id, m, 1, i_begin, i_end);
i_begin = 0;
i_end = m;
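// The row range is deliberately not split: every thread covers all rows,
// and parallelism comes from partitioning B's column blocks below.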
for (auto m0 = i_begin; m0 < i_end; m0 += mb_max) {
int mb = std::min(mb_max, i_end - m0);
assert(mb < KernelInfo::partition.size());
for (auto k_ind = 0; k_ind < k; k_ind += Bp.blockRowSize()) {
// set up proper accumulation to avoid the "NaN" problem
@@ -249,46 +261,66 @@ FBGEMM_API void cblas_gemm_compute(
gp.ldc = ldc * sizeof(C[0]);
gp.b_block_cols = nbcol;
gp.b_block_size = gp.k * Bp.blockColSize() * sizeof(gp.B[0]);

if ((n % Bp.blockColSize()) == 0) {
KernelInfo::kernel[kernel_nrows](&gp);
int jb_begin, jb_end;
fbgemmGetRange(
num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
gp.C += 8 * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
}
} else {
int last_blk_col = nbcol * Bp.blockColSize();
if (nbcol) {
KernelInfo::kernel[kernel_nrows](&gp);
int jb_begin, jb_end;
fbgemmGetRange(
num_threads, thread_id, gp.b_block_cols, 1, jb_begin, jb_end);
gp.B += gp.k * Bp.blockColSize() * jb_begin;
gp.C += 8 * jb_begin;
gp.b_block_cols = jb_end - jb_begin;
if (gp.b_block_cols) {
KernelInfo::kernel[kernel_nrows](&gp);
}
}

// leftover
int rem = n - last_blk_col;
assert(rem < kernel_ncols);
int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
: (rem / simd_width);
assert(b == 1);
if ((rem % simd_width) == 0) {
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
} else {
// small temporary buffer
float c_tmp[16 * 24] = {0};
assert((16 * 24) > kernel_nrows * kernel_ncols);
// use one thread to handle the fringe cases
if (thread_id == num_threads - 1) {
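// The fringe columns (n % blockColSize) span less than one block,
// so they are processed serially by the last thread instead of split.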
// leftover
int rem = n - last_blk_col;
assert(rem < kernel_ncols);
int b = (rem % simd_width) ? ((rem + simd_width) / simd_width)
: (rem / simd_width);
assert(b == 1);
if ((rem % simd_width) == 0) {
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = &C[m2 * ldc + last_blk_col];
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
} else {
// small temporary buffer
float c_tmp[16 * 24] = {0};
assert((16 * 24) > kernel_nrows * kernel_ncols);

gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
gp.ldc = 8 * sizeof(C[0]);
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
assert(
i * 8 + (j - last_blk_col) <
sizeof(c_tmp) / sizeof(c_tmp[0]));
if (accum == 0) {
C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
} else {
C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
c_tmp[i * 8 + (j - last_blk_col)];
gp.B = &(Bp(k_ind, last_blk_col));
gp.C = c_tmp;
gp.ldc = 8 * sizeof(C[0]);
gp.b_block_cols = 1;
KernelInfo::kernel[kernel_nrows](&gp);
for (int i = 0; i < kernel_nrows; i++) {
// Todo: use assembly
for (int j = last_blk_col; j < n; j++) {
assert(
i * 8 + (j - last_blk_col) <
sizeof(c_tmp) / sizeof(c_tmp[0]));
if (accum == 0) {
C[(m2 + i) * ldc + j] = c_tmp[i * 8 + (j - last_blk_col)];
} else {
C[(m2 + i) * ldc + j] = beta_ * C[(m2 + i) * ldc + j] +
c_tmp[i * 8 + (j - last_blk_col)];
}
}
}
}
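
fbgemmGetRange itself is not part of this diff (it comes in via the new fbgemm/Fbgemm.h include). A hypothetical sketch of the contract the call sites above rely on: disjoint, granularity-aligned subranges of [0, work) that together cover the whole range, with later threads possibly receiving less or no work. This is an illustrative reimplementation only, not FBGEMM's actual helper:

#include <algorithm>

// Splits [0, work) into per-thread chunks rounded up to `granularity`.
void getRangeSketch(
    int num_threads,
    int thread_id,
    int work,
    int granularity,
    int& begin,
    int& end) {
  int chunk = (work + num_threads - 1) / num_threads;
  chunk = (chunk + granularity - 1) / granularity * granularity;
  begin = std::min(thread_id * chunk, work);
  end = std::min(begin + chunk, work);
}

Given such a split, the code above advances gp.B by jb_begin packed column blocks and gp.C by 8 * jb_begin floats (each block column is 8 floats wide), then runs the kernel on the jb_end - jb_begin blocks owned by this thread.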
16 changes: 15 additions & 1 deletion test/FP16Test.cc
@@ -6,6 +6,10 @@
*/
#include <random>

#ifdef _OPENMP
#include <omp.h>
#endif

#include <gtest/gtest.h>

#include "TestUtils.h"
@@ -97,7 +101,17 @@ TEST_P(FBGemmFP16Test, Test) {

// fbgemm fp16
PackedGemmMatrixFP16 Bp(btrans, k, n, alpha, B.data());
cblas_gemm_compute(atrans, m, A.data(), Bp, beta, C.data());

#ifdef _OPENMP
#pragma omp parallel
#endif
{
int num_threads = fbgemm_get_num_threads();
int tid = fbgemm_get_thread_num();

cblas_gemm_compute(
atrans, m, A.data(), Bp, beta, C.data(), tid, num_threads);
}

// correctness check
for (int i = 0; i < m; ++i) {
