diff --git a/CMakeLists.txt b/CMakeLists.txt index c0ae7cb250..65aeaa6cf4 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,12 +109,21 @@ set(FBGEMM_AVX512_SRCS #for MSVC is different. Also, MSVC doesn't support inline assembly #for x64 builds. We fallback to default slower kernel for MSVC builds #for now. -list(APPEND FBGEMM_AVX2_SRCS - src/FbgemmFP16UKernelsAvx2.cc) +if(MSVC) + list(APPEND FBGEMM_AVX2_SRCS + src/FbgemmFP16UKernelsIntrinsicAvx2.cc) + + list(APPEND FBGEMM_AVX512_SRCS + src/FbgemmFP16UKernelsIntrinsicAvx512.cc + src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc) +else() + list(APPEND FBGEMM_AVX2_SRCS + src/FbgemmFP16UKernelsAvx2.cc) -list(APPEND FBGEMM_AVX512_SRCS - src/FbgemmFP16UKernelsAvx512.cc - src/FbgemmFP16UKernelsAvx512_256.cc) + list(APPEND FBGEMM_AVX512_SRCS + src/FbgemmFP16UKernelsAvx512.cc + src/FbgemmFP16UKernelsAvx512_256.cc) +endif() set(FBGEMM_PUBLIC_HEADERS include/fbgemm/Fbgemm.h include/fbgemm/FbgemmBuild.h diff --git a/bench/FP16Benchmark.cc b/bench/FP16Benchmark.cc index 5d693ede53..2c1556d286 100644 --- a/bench/FP16Benchmark.cc +++ b/bench/FP16Benchmark.cc @@ -365,15 +365,7 @@ int main(int argc, const char* argv[]) { if (num_instances > 1) { // Set-up execution for multi-instance mode // Number of threads in OpenMP parallel region is explicitly - // set to the number of instances to be executed - // If not previosly set by KMP_AFFINITY env. variable - // threads are affinitized sequentially to logical processors - char env_var[1024]; - sprintf( - env_var, "granularity=fine,explicit,proclist=[1-%d]", num_instances); -#ifndef _MSC_VER - setenv("KMP_AFFINITY", env_var, 0); // Don't overide if already set -#endif + // set to the number of instances to be executed. 
omp_set_num_threads(num_instances); #ifdef USE_MKL // each instance should be run with a single thread diff --git a/bench/PackedRequantizeAcc16Benchmark.cc b/bench/PackedRequantizeAcc16Benchmark.cc index 3553024192..4b472a4c78 100644 --- a/bench/PackedRequantizeAcc16Benchmark.cc +++ b/bench/PackedRequantizeAcc16Benchmark.cc @@ -9,9 +9,9 @@ #include #include #include +#include #include #include -#include #ifdef _OPENMP #include @@ -156,7 +156,7 @@ void performance_test() { }, NWARMUP, NITER, - [&] () { + [&]() { if (flush) { llc_flush(llc); } diff --git a/include/fbgemm/FbgemmBuild.h b/include/fbgemm/FbgemmBuild.h index 99b8508628..72709c731a 100644 --- a/include/fbgemm/FbgemmBuild.h +++ b/include/fbgemm/FbgemmBuild.h @@ -55,6 +55,7 @@ #if __clang__ || __GNUC__ >= 4 || __INTEL_COMPILER #define ALWAYS_INLINE inline __attribute__((__always_inline__)) #elif _MSC_VER + // __forceinline is commented out because it takes too long in MSVC #define ALWAYS_INLINE // __forceinline #else #define ALWAYS_INLINE inline diff --git a/src/FbgemmFP16UKernelsAvx2.cc b/src/FbgemmFP16UKernelsAvx2.cc index f23445c56c..0670295b6c 100644 --- a/src/FbgemmFP16UKernelsAvx2.cc +++ b/src/FbgemmFP16UKernelsAvx2.cc @@ -5,13 +5,9 @@ * LICENSE file in the root directory of this source tree. 
*/ #include "./FbgemmFP16UKernelsAvx2.h" -#ifdef _MSC_VER -#include -#endif namespace fbgemm { -#ifndef _MSC_VER void NOINLINE gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { asm volatile( @@ -1045,94 +1041,5 @@ gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { "r15", "memory"); } -#else // _MSC_VER -// Intrinsic kernel for MSVC -void gemmkernel_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp, const size_t kernel_nrows) { - // register buffer - __m256 ymmSum[12]; - size_t idxA = 0, idxB = 0, idxC = 0; - // ldc in float size - size_t ldc_floatsize = gp->ldc / sizeof(float); - // load beta - __m256 ymmBeta; - if (gp->beta != 0) - ymmBeta = _mm256_broadcast_ss(&gp->beta); - - // outer loop - block columns - for(uint64_t ii = 0; ii < gp->b_block_cols; ii++) { - // reset index - idxA = 0; - // inner loop - k - for(uint64_t kk = 0; kk < gp->k; kk++) { - // load B - __m256 ymmB0 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB))); - __m256 ymmB1 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB + 8))); - idxB += 16; - - // first element - if (kk == 0) { - if(gp->beta != 0) { // accumulate - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + beta * C - ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + jj * ldc_floatsize))); - ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + 8 + jj * ldc_floatsize))); - } - idxA += kernel_nrows; - } else { // set zero - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B - ymmSum[2 * jj] = _mm256_mul_ps(ymmA, ymmB0); - ymmSum[2 * jj + 1] = _mm256_mul_ps(ymmA, ymmB1); - } - idxA += kernel_nrows; - } - } else { - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = 
_mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + C - ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, ymmSum[2 * jj]); - ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, ymmSum[2 * jj + 1]); - } - idxA += kernel_nrows; - } - } - // store C - for(size_t jj = 0; jj < kernel_nrows; jj++) { - _mm256_storeu_ps(gp->C + idxC + jj * ldc_floatsize, ymmSum[2 * jj]); - _mm256_storeu_ps(gp->C + idxC + 8 + jj * ldc_floatsize, ymmSum[2 * jj + 1]); - } - idxC += 16; - } -} -void NOINLINE -gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 1); -} -void NOINLINE -gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 2); -} -void NOINLINE -gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 3); -} -void NOINLINE -gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 4); -} -void NOINLINE -gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 5); -} -void NOINLINE -gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 6); -} -#endif // _MSC_VER } // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsAvx512.cc b/src/FbgemmFP16UKernelsAvx512.cc index 4f2b28abf8..e1e6e6064a 100644 --- a/src/FbgemmFP16UKernelsAvx512.cc +++ b/src/FbgemmFP16UKernelsAvx512.cc @@ -5,13 +5,9 @@ * LICENSE file in the root directory of this source tree. 
*/ #include "./FbgemmFP16UKernelsAvx512.h" -#ifdef _MSC_VER -#include -#endif namespace fbgemm { -#ifndef _MSC_VER void NOINLINE gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { asm volatile( @@ -3489,127 +3485,5 @@ gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { "r15", "memory"); } -#else // _MSC_VER -// Intrinsic kernel for MSVC -void gemmkernel_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp, const size_t kernel_nrows) { - // register buffer - __m512 zmmSum[28]; - size_t idxA = 0, idxB = 0, idxC = 0; - // ldc in float size - size_t ldc_floatsize = gp->ldc / sizeof(float); - // load beta - __m512 zmmBeta; - if (gp->beta != 0) - zmmBeta = _mm512_broadcastss_ps(_mm_broadcast_ss(&gp->beta)); - - // outer loop - block columns - for(uint64_t ii = 0; ii < gp->b_block_cols; ii++) { - // reset index - idxA = 0; - // inner loop - k - for(uint64_t kk = 0; kk < gp->k; kk++) { - // load B - __m512 zmmB0 = _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB))); - __m512 zmmB1 = _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB + 16))); - idxB += 32; - - // first element - if (kk == 0) { - if(gp->beta != 0) { // accumulate - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + beta * C - zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, _mm512_mul_ps(zmmBeta, _mm512_loadu_ps(gp->C + idxC + jj * ldc_floatsize))); - zmmSum[2 * jj + 1] = _mm512_fmadd_ps(zmmA, zmmB1, _mm512_mul_ps(zmmBeta, _mm512_loadu_ps(gp->C + idxC + 16 + jj * ldc_floatsize))); - } - idxA += kernel_nrows; - } else { // set zero - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B - zmmSum[2 * jj] = _mm512_mul_ps(zmmA, zmmB0); - zmmSum[2 * jj + 1] = _mm512_mul_ps(zmmA, zmmB1); - } - idxA += kernel_nrows; - } - } else { - for(size_t jj = 0; jj < kernel_nrows; jj++) 
{ - // load A - __m512 zmmA = _mm512_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + C - zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, zmmSum[2 * jj]); - zmmSum[2 * jj + 1] = _mm512_fmadd_ps(zmmA, zmmB1, zmmSum[2 * jj + 1]); - } - idxA += kernel_nrows; - } - } - // store C - for(size_t jj = 0; jj < kernel_nrows; jj++) { - _mm512_storeu_ps(gp->C + idxC + jj * ldc_floatsize, zmmSum[2 * jj]); - _mm512_storeu_ps(gp->C + idxC + 16 + jj * ldc_floatsize, zmmSum[2 * jj + 1]); - } - idxC += 32; - } -} - -void NOINLINE -gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 1); -} -void NOINLINE -gemmkernel_2x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 2); -} -void NOINLINE -gemmkernel_3x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 3); -} -void NOINLINE -gemmkernel_4x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 4); -} -void NOINLINE -gemmkernel_5x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 5); -} -void NOINLINE -gemmkernel_6x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 6); -} -void NOINLINE -gemmkernel_7x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 7); -} -void NOINLINE -gemmkernel_8x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 8); -} -void NOINLINE -gemmkernel_9x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 9); -} -void NOINLINE -gemmkernel_10x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 10); -} -void NOINLINE -gemmkernel_11x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 11); -} -void NOINLINE -gemmkernel_12x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 12); -} 
-void NOINLINE -gemmkernel_13x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 13); -} -void NOINLINE -gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 14); -} -#endif // _MSC_VER } // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsAvx512_256.cc b/src/FbgemmFP16UKernelsAvx512_256.cc index 07b5288f08..dccc5c284a 100644 --- a/src/FbgemmFP16UKernelsAvx512_256.cc +++ b/src/FbgemmFP16UKernelsAvx512_256.cc @@ -5,13 +5,9 @@ * LICENSE file in the root directory of this source tree. */ #include "./FbgemmFP16UKernelsAvx512_256.h" -#ifdef _MSC_VER -#include -#endif namespace fbgemm { -#ifndef _MSC_VER void NOINLINE gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { asm volatile( @@ -2456,102 +2452,5 @@ gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { "r15", "memory"); } -#else // _MSC_VER -// Intrinsic kernel for MSVC -void gemmkernel_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp, const size_t kernel_nrows) { - // register buffer - __m256 ymmSum[28]; - size_t idxA = 0, idxB = 0, idxC = 0; - // ldc in float size - size_t ldc_floatsize = gp->ldc / sizeof(float); - // load beta - __m256 ymmBeta; - if (gp->beta != 0) - ymmBeta = _mm256_broadcast_ss(&gp->beta); - - // outer loop - block columns - for(uint64_t ii = 0; ii < gp->b_block_cols; ii++) { - // reset index - idxA = 0; - // inner loop - k - for(uint64_t kk = 0; kk < gp->k; kk++) { - // load B - __m256 ymmB0 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB))); - __m256 ymmB1 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB + 8))); - idxB += 16; - - // first element - if (kk == 0) { - if(gp->beta != 0) { // accumulate - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + beta * C - ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + jj 
* ldc_floatsize))); - ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, _mm256_mul_ps(ymmBeta, _mm256_loadu_ps(gp->C + idxC + 8 + jj * ldc_floatsize))); - } - idxA += kernel_nrows; - } else { // set zero - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B - ymmSum[2 * jj] = _mm256_mul_ps(ymmA, ymmB0); - ymmSum[2 * jj + 1] = _mm256_mul_ps(ymmA, ymmB1); - } - idxA += kernel_nrows; - } - } else { - for(size_t jj = 0; jj < kernel_nrows; jj++) { - // load A - __m256 ymmA = _mm256_broadcastss_ps(_mm_broadcast_ss((float const*)(gp->A + idxA + jj))); - // C = A * B + C - ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, ymmSum[2 * jj]); - ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, ymmSum[2 * jj + 1]); - } - idxA += kernel_nrows; - } - } - // store C - for(size_t jj = 0; jj < kernel_nrows; jj++) { - _mm256_storeu_ps(gp->C + idxC + jj * ldc_floatsize, ymmSum[2 * jj]); - _mm256_storeu_ps(gp->C + idxC + 8 + jj * ldc_floatsize, ymmSum[2 * jj + 1]); - } - idxC += 16; - } -} -void NOINLINE -gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 7); -} -void NOINLINE -gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 8); -} -void NOINLINE -gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 9); -} -void NOINLINE -gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 10); -} -void NOINLINE -gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 11); -} -void NOINLINE -gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 12); -} -void NOINLINE -gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 
13); -} -void NOINLINE -gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { - gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 14); -} -#endif // _MSC_VER } // namespace fbgemm diff --git a/src/FbgemmFP16UKernelsIntrinsicAvx2.cc b/src/FbgemmFP16UKernelsIntrinsicAvx2.cc new file mode 100644 index 0000000000..8785661c80 --- /dev/null +++ b/src/FbgemmFP16UKernelsIntrinsicAvx2.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef _MSC_VER +#include "./FbgemmFP16UKernelsAvx2.h" +#include + +namespace fbgemm { + +// Intrinsics-based kernel for MSVC: for each of gp->b_block_cols block columns it accumulates a kernel_nrows x 16 tile of C over gp->k steps and stores it +void gemmkernel_Avx2_fp16_fA0fB0fC0( + GemmParamsFP16* gp, + const size_t kernel_nrows) { + // accumulators: two 8-float vectors per tile row, so 12 slots cover at most 6 rows + __m256 ymmSum[12]; + size_t idxA = 0, idxB = 0, idxC = 0; + // row stride of C in floats, computed as gp->ldc / sizeof(float) + size_t ldc_floatsize = gp->ldc / sizeof(float); + // broadcast beta; ymmBeta is only read on the beta != 0 path below + __m256 ymmBeta; + if (gp->beta != 0) + ymmBeta = _mm256_broadcast_ss(&gp->beta); + + // outer loop - block columns + for (uint64_t ii = 0; ii < gp->b_block_cols; ii++) { + // A is re-read from its first element for each block column + idxA = 0; + // inner loop - k + for (uint64_t kk = 0; kk < gp->k; kk++) { + // load 16 fp16 values of B and widen them to two vectors of 8 floats + __m256 ymmB0 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB))); + __m256 ymmB1 = + _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB + 8))); + idxB += 16; + + // first k step: initialize the accumulators + if (kk == 0) { + if (gp->beta != 0) { // accumulate + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + beta * C + ymmSum[2 * jj] = _mm256_fmadd_ps( + ymmA, + ymmB0, + _mm256_mul_ps( + ymmBeta, + _mm256_loadu_ps(gp->C + idxC + jj * ldc_floatsize))); + ymmSum[2 * jj + 1] = _mm256_fmadd_ps( + ymmA, + ymmB1, + _mm256_mul_ps( + ymmBeta, + _mm256_loadu_ps(gp->C + idxC + 8 + jj * ldc_floatsize))); + } + 
idxA += kernel_nrows; + } else { // beta == 0: start from A * B without reading C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + ymmSum[2 * jj] = _mm256_mul_ps(ymmA, ymmB0); + ymmSum[2 * jj + 1] = _mm256_mul_ps(ymmA, ymmB1); + } + idxA += kernel_nrows; + } + } else { + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + C + ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, ymmSum[2 * jj]); + ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, ymmSum[2 * jj + 1]); + } + idxA += kernel_nrows; + } + } + // write the finished tile rows back to C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + _mm256_storeu_ps(gp->C + idxC + jj * ldc_floatsize, ymmSum[2 * jj]); + _mm256_storeu_ps( + gp->C + idxC + 8 + jj * ldc_floatsize, ymmSum[2 * jj + 1]); + } + idxC += 16; + } +} + +void NOINLINE +gemmkernel_1x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 1); +} +void NOINLINE +gemmkernel_2x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 2); +} +void NOINLINE +gemmkernel_3x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 3); +} +void NOINLINE +gemmkernel_4x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 4); +} +void NOINLINE +gemmkernel_5x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 5); +} +void NOINLINE +gemmkernel_6x2_Avx2_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx2_fp16_fA0fB0fC0(gp, 6); +} + +} // namespace fbgemm +#endif // _MSC_VER diff --git a/src/FbgemmFP16UKernelsIntrinsicAvx512.cc b/src/FbgemmFP16UKernelsIntrinsicAvx512.cc new file mode 100644 index 0000000000..7ac34fd046 --- /dev/null +++ b/src/FbgemmFP16UKernelsIntrinsicAvx512.cc @@ -0,0 +1,140 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. 
+ * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef _MSC_VER +#include "./FbgemmFP16UKernelsAvx512.h" +#include + +namespace fbgemm { + +// Intrinsics-based kernel for MSVC: for each of gp->b_block_cols block columns it accumulates a kernel_nrows x 32 tile of C over gp->k steps and stores it +void gemmkernel_Avx512_fp16_fA0fB0fC0( + GemmParamsFP16* gp, + const size_t kernel_nrows) { + // accumulators: two 16-float vectors per tile row, so 28 slots cover at most 14 rows + __m512 zmmSum[28]; + size_t idxA = 0, idxB = 0, idxC = 0; + // row stride of C in floats, computed as gp->ldc / sizeof(float) + size_t ldc_floatsize = gp->ldc / sizeof(float); + // broadcast beta; zmmBeta is only read on the beta != 0 path below + __m512 zmmBeta; + if (gp->beta != 0) + zmmBeta = _mm512_broadcastss_ps(_mm_broadcast_ss(&gp->beta)); + + // outer loop - block columns + for (uint64_t ii = 0; ii < gp->b_block_cols; ii++) { + // A is re-read from its first element for each block column + idxA = 0; + // inner loop - k + for (uint64_t kk = 0; kk < gp->k; kk++) { + // load 32 fp16 values of B and widen them to two vectors of 16 floats + __m512 zmmB0 = + _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB))); + __m512 zmmB1 = + _mm512_cvtph_ps(_mm256_load_si256((__m256i*)(gp->B + idxB + 16))); + idxB += 32; + + // first k step: initialize the accumulators + if (kk == 0) { + if (gp->beta != 0) { // accumulate + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m512 zmmA = _mm512_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + beta * C + zmmSum[2 * jj] = _mm512_fmadd_ps( + zmmA, + zmmB0, + _mm512_mul_ps( + zmmBeta, + _mm512_loadu_ps(gp->C + idxC + jj * ldc_floatsize))); + zmmSum[2 * jj + 1] = _mm512_fmadd_ps( + zmmA, + zmmB1, + _mm512_mul_ps( + zmmBeta, + _mm512_loadu_ps(gp->C + idxC + 16 + jj * ldc_floatsize))); + } + idxA += kernel_nrows; + } else { // beta == 0: start from A * B without reading C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m512 zmmA = _mm512_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + zmmSum[2 * jj] = _mm512_mul_ps(zmmA, zmmB0); + zmmSum[2 * jj + 1] = _mm512_mul_ps(zmmA, zmmB1); + } + idxA += kernel_nrows; + } + } else { + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + 
__m512 zmmA = _mm512_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + C + zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, zmmSum[2 * jj]); + zmmSum[2 * jj + 1] = _mm512_fmadd_ps(zmmA, zmmB1, zmmSum[2 * jj + 1]); + } + idxA += kernel_nrows; + } + } + // write the finished tile rows back to C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + _mm512_storeu_ps(gp->C + idxC + jj * ldc_floatsize, zmmSum[2 * jj]); + _mm512_storeu_ps( + gp->C + idxC + 16 + jj * ldc_floatsize, zmmSum[2 * jj + 1]); + } + idxC += 32; + } +} + +void NOINLINE gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 1); +} +void NOINLINE gemmkernel_2x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 2); +} +void NOINLINE gemmkernel_3x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 3); +} +void NOINLINE gemmkernel_4x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 4); +} +void NOINLINE gemmkernel_5x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 5); +} +void NOINLINE gemmkernel_6x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 6); +} +void NOINLINE gemmkernel_7x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 7); +} +void NOINLINE gemmkernel_8x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 8); +} +void NOINLINE gemmkernel_9x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 9); +} +void NOINLINE gemmkernel_10x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 10); +} +void NOINLINE gemmkernel_11x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 11); +} +void NOINLINE gemmkernel_12x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 12); +} +void NOINLINE 
gemmkernel_13x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 13); +} +void NOINLINE gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 14); +} + +} // namespace fbgemm +#endif // _MSC_VER diff --git a/src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc b/src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc new file mode 100644 index 0000000000..6b219d2bb1 --- /dev/null +++ b/src/FbgemmFP16UKernelsIntrinsicAvx512_256.cc @@ -0,0 +1,121 @@ +/* + * Copyright (c) Facebook, Inc. and its affiliates. + * All rights reserved. + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#ifdef _MSC_VER +#include "./FbgemmFP16UKernelsAvx512_256.h" +#include + +namespace fbgemm { + +// Intrinsics-based kernel for MSVC: for each of gp->b_block_cols block columns it accumulates a kernel_nrows x 16 tile of C over gp->k steps and stores it +void gemmkernel_Avx512_256_fp16_fA0fB0fC0( + GemmParamsFP16* gp, + const size_t kernel_nrows) { + // accumulators: two 8-float vectors per tile row, so 28 slots cover at most 14 rows + __m256 ymmSum[28]; + size_t idxA = 0, idxB = 0, idxC = 0; + // row stride of C in floats, computed as gp->ldc / sizeof(float) + size_t ldc_floatsize = gp->ldc / sizeof(float); + // broadcast beta; ymmBeta is only read on the beta != 0 path below + __m256 ymmBeta; + if (gp->beta != 0) + ymmBeta = _mm256_broadcast_ss(&gp->beta); + + // outer loop - block columns + for (uint64_t ii = 0; ii < gp->b_block_cols; ii++) { + // A is re-read from its first element for each block column + idxA = 0; + // inner loop - k + for (uint64_t kk = 0; kk < gp->k; kk++) { + // load 16 fp16 values of B and widen them to two vectors of 8 floats + __m256 ymmB0 = _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB))); + __m256 ymmB1 = + _mm256_cvtph_ps(_mm_load_si128((__m128i*)(gp->B + idxB + 8))); + idxB += 16; + + // first k step: initialize the accumulators + if (kk == 0) { + if (gp->beta != 0) { // accumulate + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + beta * C + ymmSum[2 * jj] = _mm256_fmadd_ps( + ymmA, + ymmB0, + _mm256_mul_ps( + ymmBeta, + _mm256_loadu_ps(gp->C + idxC + jj * ldc_floatsize))); + ymmSum[2 * jj + 1] = 
_mm256_fmadd_ps( + ymmA, + ymmB1, + _mm256_mul_ps( + ymmBeta, + _mm256_loadu_ps(gp->C + idxC + 8 + jj * ldc_floatsize))); + } + idxA += kernel_nrows; + } else { // beta == 0: start from A * B without reading C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + ymmSum[2 * jj] = _mm256_mul_ps(ymmA, ymmB0); + ymmSum[2 * jj + 1] = _mm256_mul_ps(ymmA, ymmB1); + } + idxA += kernel_nrows; + } + } else { + for (size_t jj = 0; jj < kernel_nrows; jj++) { + // broadcast one element of A to all lanes + __m256 ymmA = _mm256_broadcastss_ps( + _mm_broadcast_ss((float const*)(gp->A + idxA + jj))); + // C = A * B + C + ymmSum[2 * jj] = _mm256_fmadd_ps(ymmA, ymmB0, ymmSum[2 * jj]); + ymmSum[2 * jj + 1] = _mm256_fmadd_ps(ymmA, ymmB1, ymmSum[2 * jj + 1]); + } + idxA += kernel_nrows; + } + } + // write the finished tile rows back to C + for (size_t jj = 0; jj < kernel_nrows; jj++) { + _mm256_storeu_ps(gp->C + idxC + jj * ldc_floatsize, ymmSum[2 * jj]); + _mm256_storeu_ps( + gp->C + idxC + 8 + jj * ldc_floatsize, ymmSum[2 * jj + 1]); + } + idxC += 16; + } +} + +void NOINLINE gemmkernel_7x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 7); +} +void NOINLINE gemmkernel_8x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 8); +} +void NOINLINE gemmkernel_9x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 9); +} +void NOINLINE gemmkernel_10x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 10); +} +void NOINLINE gemmkernel_11x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 11); +} +void NOINLINE gemmkernel_12x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 12); +} +void NOINLINE gemmkernel_13x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 13); +} +void NOINLINE 
gemmkernel_14x2_Avx512_256_fp16_fA0fB0fC0(GemmParamsFP16* gp) { + gemmkernel_Avx512_256_fp16_fA0fB0fC0(gp, 14); +} + +} // namespace fbgemm +#endif // _MSC_VER diff --git a/test/Im2ColFusedRequantizeTest.cc b/test/Im2ColFusedRequantizeTest.cc index 7e6c167667..99ded3162a 100644 --- a/test/Im2ColFusedRequantizeTest.cc +++ b/test/Im2ColFusedRequantizeTest.cc @@ -6,8 +6,8 @@ */ #include #include -#include #include +#include #ifdef _OPENMP #include diff --git a/test/PackedRequantizeAcc16Test.cc b/test/PackedRequantizeAcc16Test.cc index 8b23f2c12c..781755ce0f 100644 --- a/test/PackedRequantizeAcc16Test.cc +++ b/test/PackedRequantizeAcc16Test.cc @@ -8,11 +8,9 @@ #include #include #include +#include #include #include -#ifdef _MSC_VER -#include -#endif #ifdef _OPENMP #include