Skip to content

Commit

Permalink
modified intrinsic fp16 kernel for windows build (pytorch#259)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#259

As title

Reviewed By: shz0116

Differential Revision: D19460978

fbshipit-source-id: 7b92980f1ba0f6e5b3ff8f43353e2acdebf96808
  • Loading branch information
jspark1105 committed Mar 21, 2020
1 parent 04e4804 commit c373524
Show file tree
Hide file tree
Showing 12 changed files with 402 additions and 340 deletions.
19 changes: 14 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -109,12 +109,21 @@ set(FBGEMM_AVX512_SRCS
#for MSVC is different. Also, MSVC doesn't support inline assembly
#for x64 builds. We fall back to default slower kernel for MSVC builds
#for now.
list(APPEND FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsAvx2.cc)
if(MSVC)
list(APPEND FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsIntrinsicAvx2.cc)

list(APPEND FBGEMM_AVX512_SRCS
src/FbgemmFP16UKernelsIntrinsicAvx512.cc
src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc)
else()
list(APPEND FBGEMM_AVX2_SRCS
src/FbgemmFP16UKernelsAvx2.cc)

list(APPEND FBGEMM_AVX512_SRCS
src/FbgemmFP16UKernelsAvx512.cc
src/FbgemmFP16UKernelsAvx512_256.cc)
list(APPEND FBGEMM_AVX512_SRCS
src/FbgemmFP16UKernelsAvx512.cc
src/FbgemmFP16UKernelsAvx512_256.cc)
endif()

set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h
include/fbgemm/FbgemmBuild.h
Expand Down
10 changes: 1 addition & 9 deletions bench/FP16Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -365,15 +365,7 @@ int main(int argc, const char* argv[]) {
if (num_instances > 1) {
// Set-up execution for multi-instance mode
// Number of threads in OpenMP parallel region is explicitly
// set to the number of instances to be executed
// If not previously set by KMP_AFFINITY env. variable
// threads are affinitized sequentially to logical processors
char env_var[1024];
sprintf(
env_var, "granularity=fine,explicit,proclist=[1-%d]", num_instances);
#ifndef _MSC_VER
setenv("KMP_AFFINITY", env_var, 0); // Don't overide if already set
#endif
// set to the number of instances to be executed.
omp_set_num_threads(num_instances);
#ifdef USE_MKL
// each instance should be run with a single thread
Expand Down
4 changes: 2 additions & 2 deletions bench/PackedRequantizeAcc16Benchmark.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@
#include <cmath>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <random>
#include <vector>
#include <numeric>

#ifdef _OPENMP
#include <omp.h>
Expand Down Expand Up @@ -156,7 +156,7 @@ void performance_test() {
},
NWARMUP,
NITER,
[&] () {
[&]() {
if (flush) {
llc_flush(llc);
}
Expand Down
1 change: 1 addition & 0 deletions include/fbgemm/FbgemmBuild.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
#if __clang__ || __GNUC__ >= 4 || __INTEL_COMPILER
#define ALWAYS_INLINE inline __attribute__((__always_inline__))
#elif _MSC_VER
// commenting out because __forceinline takes too long in MSVC
#define ALWAYS_INLINE // __forceinline
#else
#define ALWAYS_INLINE inline
Expand Down
93 changes: 0 additions & 93 deletions src/FbgemmFP16UKernelsAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@
* LICENSE file in the root directory of this source tree.
*/
#include "./FbgemmFP16UKernelsAvx2.h"
#ifdef _MSC_VER
#include <immintrin.h>
#endif

namespace fbgemm {

#ifndef _MSC_VER
void NOINLINE
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
asm volatile(
Expand Down Expand Up @@ -1045,94 +1041,5 @@ gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
"r15",
"memory");
}
#else // _MSC_VER
// Intrinsic kernel for MSVC
//
// Generic 256-bit intrinsic micro-kernel shared by the Nx2 AVX2 entry points.
// For each of gp->b_block_cols column blocks (16 floats wide, held as two
// __m256 vectors per row) it accumulates, over gp->k steps, the product of one
// broadcast A value per row with two vectors of B converted from half
// precision, and then stores the result to C.
//
// gp            kernel parameters; this function reads k, A, B, beta, C, ldc
//               and b_block_cols, and writes through gp->C.
// kernel_nrows  number of C rows processed. The accumulator array holds two
//               vectors per row and has 12 entries, so kernel_nrows must not
//               exceed 6.
//
// NOTE(review): if gp->k is 0 the k loop never runs and ymmSum is stored to C
// without ever being written - confirm callers never pass k == 0.
void gemmkernel_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp, const size_t kernel_nrows) {
// register buffer: accumulators, indexed [2 * row] and [2 * row + 1]
__m256 ymmSum[12];
// running element offsets into A, B and C
size_t idxA = 0, idxB = 0, idxC = 0;
// ldc in float size
// (gp->ldc is divided by sizeof(float), so it is presumably a byte stride -
// TODO confirm against the struct definition)
size_t ldc_floatsize = gp->ldc / sizeof(float);
// load beta
// ymmBeta is only initialized when gp->beta is non-zero; it is only read
// below under the same condition.
__m256 ymmBeta;
if (gp->beta != 0)
ymmBeta = _mm256_broadcast_ss(&gp->beta);

// outer loop - block columns
for(uint64_t ii = 0; ii < gp->b_block_cols; ii++) {
// reset index
// A is re-read from its start for every column block; idxB is not reset, so
// B is consumed sequentially across column blocks.
idxA = 0;
// inner loop - k
for(uint64_t kk = 0; kk < gp->k; kk++) {
// load B
// Two aligned 128-bit loads, each converted from 8 half-precision values
// to 8 floats. _mm_load_si128 requires 16-byte-aligned addresses.
// (offsets 0 / +8 and the += 16 step assume gp->B points to 16-bit
// elements - TODO confirm)
__m256 ymmB0 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB)));
__m256 ymmB1 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB + 8)));
idxB += 16;

// first element
// The first k step initializes the accumulators instead of adding to them.
if (kk == 0) {
if(gp->beta != 0) { // accumulate
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
// one scalar per row, broadcast to all 8 lanes
__m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B + beta * C
ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + jj * ldc_floatsize)));
ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + 8 + jj * ldc_floatsize)));
}
// A holds kernel_nrows consecutive values per k step
idxA += kernel_nrows;
} else { // set zero
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
__m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B
// (existing C contents are ignored in this branch)
ymmSum[2 * jj] = _mm256_mul_ps(ymmA, ymmB0);
ymmSum[2 * jj + 1] = _mm256_mul_ps(ymmA, ymmB1);
}
idxA += kernel_nrows;
}
} else {
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
__m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B + C
// (accumulates into the register buffer, not into memory)
ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, ymmSum[2 * jj]);
ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, ymmSum[2 * jj + 1]);
}
idxA += kernel_nrows;
}
}
// store C
// unaligned stores: two vectors (16 floats) per row of this column block
for(size_t jj = 0; jj < kernel_nrows; jj++) {
_mm256_storeu_ps(gp->C + idxC + jj * ldc_floatsize, ymmSum[2 * jj]);
_mm256_storeu_ps(gp->C + idxC + 8 + jj * ldc_floatsize, ymmSum[2 * jj + 1]);
}
// advance to the next 16-float column block
idxC += 16;
}
}

// Fixed-row-count entry points for the MSVC build. Each one forwards to the
// generic intrinsic kernel gemmkernel_Avx2_fp16_fA0fB0fC0 with the row count
// encoded in its name (1 through 6).

// 1 row.
void NOINLINE
gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 1);
}
// 2 rows.
void NOINLINE
gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 2);
}
// 3 rows.
void NOINLINE
gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 3);
}
// 4 rows.
void NOINLINE
gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 4);
}
// 5 rows.
void NOINLINE
gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 5);
}
// 6 rows.
void NOINLINE
gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 6);
}
#endif // _MSC_VER
} // namespace fbgemm
126 changes: 0 additions & 126 deletions src/FbgemmFP16UKernelsAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,9 @@
* LICENSE file in the root directory of this source tree.
*/
#include "./FbgemmFP16UKernelsAvx512.h"
#ifdef _MSC_VER
#include <immintrin.h>
#endif

namespace fbgemm {

#ifndef _MSC_VER
void NOINLINE
gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
asm volatile(
Expand Down Expand Up @@ -3489,127 +3485,5 @@ gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
"r15",
"memory");
}
#else // _MSC_VER
// Intrinsic kernel for MSVC
//
// Generic 512-bit intrinsic micro-kernel shared by the Nx2 AVX-512 entry
// points. For each of gp->b_block_cols column blocks (32 floats wide, held as
// two __m512 vectors per row) it accumulates, over gp->k steps, the product of
// one broadcast A value per row with two vectors of B converted from half
// precision, and then stores the result to C.
//
// gp            kernel parameters; this function reads k, A, B, beta, C, ldc
//               and b_block_cols, and writes through gp->C.
// kernel_nrows  number of C rows processed. The accumulator array holds two
//               vectors per row and has 28 entries, so kernel_nrows must not
//               exceed 14.
//
// NOTE(review): if gp->k is 0 the k loop never runs and zmmSum is stored to C
// without ever being written - confirm callers never pass k == 0.
void gemmkernel_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp, const size_t kernel_nrows) {
// register buffer: accumulators, indexed [2 * row] and [2 * row + 1]
__m512 zmmSum[28];
// running element offsets into A, B and C
size_t idxA = 0, idxB = 0, idxC = 0;
// ldc in float size
// (gp->ldc is divided by sizeof(float), so it is presumably a byte stride -
// TODO confirm against the struct definition)
size_t ldc_floatsize = gp->ldc / sizeof(float);
// load beta
// zmmBeta is only initialized when gp->beta is non-zero; it is only read
// below under the same condition.
__m512 zmmBeta;
if (gp->beta != 0)
zmmBeta = _mm512_broadcastss_ps(_mm_broadcast_ss(&gp->beta));

// outer loop - block columns
for(uint64_t ii = 0; ii < gp->b_block_cols; ii++) {
// reset index
// A is re-read from its start for every column block; idxB is not reset, so
// B is consumed sequentially across column blocks.
idxA = 0;
// inner loop - k
for(uint64_t kk = 0; kk < gp->k; kk++) {
// load B
// Two aligned 256-bit loads, each converted from 16 half-precision values
// to 16 floats. _mm256_load_si256 requires 32-byte-aligned addresses.
// (offsets 0 / +16 and the += 32 step assume gp->B points to 16-bit
// elements - TODO confirm)
__m512 zmmB0 = _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB)));
__m512 zmmB1 = _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB + 16)));
idxB += 32;

// first element
// The first k step initializes the accumulators instead of adding to them.
if (kk == 0) {
if(gp->beta != 0) { // accumulate
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
// one scalar per row, broadcast to all 16 lanes
__m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B + beta * C
zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, _mm512_mul_ps(zmmBeta, _mm512_loadu_ps(gp->C + idxC + jj * ldc_floatsize)));
zmmSum[2 * jj + 1] = _mm512_fmadd_ps(zmmA, zmmB1, _mm512_mul_ps(zmmBeta, _mm512_loadu_ps(gp->C + idxC + 16 + jj * ldc_floatsize)));
}
// A holds kernel_nrows consecutive values per k step
idxA += kernel_nrows;
} else { // set zero
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
__m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B
// (existing C contents are ignored in this branch)
zmmSum[2 * jj] = _mm512_mul_ps(zmmA, zmmB0);
zmmSum[2 * jj + 1] = _mm512_mul_ps(zmmA, zmmB1);
}
idxA += kernel_nrows;
}
} else {
for(size_t jj = 0; jj < kernel_nrows; jj++) {
// load A
__m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj)));
// C = A * B + C
// (accumulates into the register buffer, not into memory)
zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, zmmSum[2 * jj]);
zmmSum[2 * jj + 1] = _mm512_fmadd_ps(zmmA, zmmB1, zmmSum[2 * jj + 1]);
}
idxA += kernel_nrows;
}
}
// store C
// unaligned stores: two vectors (32 floats) per row of this column block
for(size_t jj = 0; jj < kernel_nrows; jj++) {
_mm512_storeu_ps(gp->C + idxC + jj * ldc_floatsize, zmmSum[2 * jj]);
_mm512_storeu_ps(gp->C + idxC + 16 + jj * ldc_floatsize, zmmSum[2 * jj + 1]);
}
// advance to the next 32-float column block
idxC += 32;
}
}

// Fixed-row-count entry points for the MSVC build. Each one forwards to the
// generic intrinsic kernel gemmkernel_Avx512_fp16_fA0fB0fC0 with the row count
// encoded in its name (1 through 14).

// 1 row.
void NOINLINE
gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 1);
}
// 2 rows.
void NOINLINE
gemmkernel_2x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 2);
}
// 3 rows.
void NOINLINE
gemmkernel_3x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 3);
}
// 4 rows.
void NOINLINE
gemmkernel_4x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 4);
}
// 5 rows.
void NOINLINE
gemmkernel_5x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 5);
}
// 6 rows.
void NOINLINE
gemmkernel_6x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 6);
}
// 7 rows.
void NOINLINE
gemmkernel_7x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 7);
}
// 8 rows.
void NOINLINE
gemmkernel_8x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 8);
}
// 9 rows.
void NOINLINE
gemmkernel_9x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 9);
}
// 10 rows.
void NOINLINE
gemmkernel_10x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 10);
}
// 11 rows.
void NOINLINE
gemmkernel_11x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 11);
}
// 12 rows.
void NOINLINE
gemmkernel_12x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 12);
}
// 13 rows.
void NOINLINE
gemmkernel_13x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 13);
}
// 14 rows.
void NOINLINE
gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 14);
}
#endif // _MSC_VER

} // namespace fbgemm
Loading

0 comments on commit c373524

Please sign in to comment.