Skip to content

Commit

Permalink
Change kernel function names to use ISA name directly (pytorch#156)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#156

Use instruction set name from the description structure while naming the kernel functions

Reviewed By: jspark1105

Differential Revision: D18239352

fbshipit-source-id: 0678e4ad01e19465e8ea7cdf6563feaf2ce7e501
  • Loading branch information
efiks authored and facebook-github-bot committed Oct 31, 2019
1 parent b884868 commit 0bf264f
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 77 deletions.
54 changes: 27 additions & 27 deletions src/FbgemmFP16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,34 +45,34 @@ inline void PackA(int nrow, int ncol, const float* from, int ldim, float* to) {
// }

struct KernelInfo {
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
// the restrictions of ymm register numbers (16).
static constexpr knl_ptr kernel_avx2[] = {nullptr,
gemmkernel_1x2_AVX2_fA0fB0fC0,
gemmkernel_2x2_AVX2_fA0fB0fC0,
gemmkernel_3x2_AVX2_fA0fB0fC0,
gemmkernel_4x2_AVX2_fA0fB0fC0,
gemmkernel_5x2_AVX2_fA0fB0fC0,
gemmkernel_6x2_AVX2_fA0fB0fC0};
using knl_ptr = funcptr_fp16;
// optimized kernels to cover all cases
// 2 in ?x2 should be the same as kernel_ncol_blocks.
// Here with kernel_ncol_blocks = 2, we can provide up to 6x2 kernels, due to
// the restrictions of ymm register numbers (16).
static constexpr knl_ptr kernel_avx2[] = {nullptr,
gemmkernel_1x2_Avx2_fA0fB0fC0,
gemmkernel_2x2_Avx2_fA0fB0fC0,
gemmkernel_3x2_Avx2_fA0fB0fC0,
gemmkernel_4x2_Avx2_fA0fB0fC0,
gemmkernel_5x2_Avx2_fA0fB0fC0,
gemmkernel_6x2_Avx2_fA0fB0fC0};

static constexpr knl_ptr kernel_avx512[] = {nullptr,
gemmkernel_1x2_AVX512_fA0fB0fC0,
gemmkernel_2x2_AVX512_fA0fB0fC0,
gemmkernel_3x2_AVX512_fA0fB0fC0,
gemmkernel_4x2_AVX512_fA0fB0fC0,
gemmkernel_5x2_AVX512_fA0fB0fC0,
gemmkernel_6x2_AVX512_fA0fB0fC0,
gemmkernel_7x2_AVX512_fA0fB0fC0,
gemmkernel_8x2_AVX512_fA0fB0fC0,
gemmkernel_9x2_AVX512_fA0fB0fC0,
gemmkernel_10x2_AVX512_fA0fB0fC0,
gemmkernel_11x2_AVX512_fA0fB0fC0,
gemmkernel_12x2_AVX512_fA0fB0fC0,
gemmkernel_13x2_AVX512_fA0fB0fC0,
gemmkernel_14x2_AVX512_fA0fB0fC0};
static constexpr knl_ptr kernel_avx512[] = {nullptr,
gemmkernel_1x2_Avx512_fA0fB0fC0,
gemmkernel_2x2_Avx512_fA0fB0fC0,
gemmkernel_3x2_Avx512_fA0fB0fC0,
gemmkernel_4x2_Avx512_fA0fB0fC0,
gemmkernel_5x2_Avx512_fA0fB0fC0,
gemmkernel_6x2_Avx512_fA0fB0fC0,
gemmkernel_7x2_Avx512_fA0fB0fC0,
gemmkernel_8x2_Avx512_fA0fB0fC0,
gemmkernel_9x2_Avx512_fA0fB0fC0,
gemmkernel_10x2_Avx512_fA0fB0fC0,
gemmkernel_11x2_Avx512_fA0fB0fC0,
gemmkernel_12x2_Avx512_fA0fB0fC0,
gemmkernel_13x2_Avx512_fA0fB0fC0,
gemmkernel_14x2_Avx512_fA0fB0fC0};

// autotuned kernel splits for various cases m = 1:mb_max
// may need re-autotuning for new uarch
Expand Down
12 changes: 6 additions & 6 deletions src/FbgemmFP16UKernelsAvx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace fbgemm {

void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_1x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -106,7 +106,7 @@ void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_2x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -217,7 +217,7 @@ void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_3x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -341,7 +341,7 @@ void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_4x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -478,7 +478,7 @@ void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_5x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -628,7 +628,7 @@ void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_6x2_Avx2_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down
12 changes: 6 additions & 6 deletions src/FbgemmFP16UKernelsAvx2.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ struct GemmParams {
uint64_t b_block_cols;
uint64_t b_block_size;
};
void __attribute__((noinline)) gemmkernel_1x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_2x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_3x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_4x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_5x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_6x2_AVX2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_1x2_Avx2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_2x2_Avx2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_3x2_Avx2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_4x2_Avx2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_5x2_Avx2_fA0fB0fC0(GemmParams* gp);
void __attribute__((noinline)) gemmkernel_6x2_Avx2_fA0fB0fC0(GemmParams* gp);
typedef void (*funcptr_fp16)(GemmParams* gp);
;

Expand Down
47 changes: 28 additions & 19 deletions src/FbgemmFP16UKernelsAvx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

namespace fbgemm {

void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_1x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -46,6 +46,7 @@ void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm0,zmm0,zmm0\t\n"
"vxorps zmm1,zmm1,zmm1\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm3,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -105,7 +106,7 @@ void __attribute__((noinline)) gemmkernel_1x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_2x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -145,6 +146,7 @@ void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm2,zmm2,zmm2\t\n"
"vxorps zmm3,zmm3,zmm3\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm5,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -215,7 +217,7 @@ void __attribute__((noinline)) gemmkernel_2x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_3x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -257,6 +259,7 @@ void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm4,zmm4,zmm4\t\n"
"vxorps zmm5,zmm5,zmm5\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm7,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -338,7 +341,7 @@ void __attribute__((noinline)) gemmkernel_3x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_4x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -382,6 +385,7 @@ void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm6,zmm6,zmm6\t\n"
"vxorps zmm7,zmm7,zmm7\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm9,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -474,7 +478,7 @@ void __attribute__((noinline)) gemmkernel_4x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_5x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -520,6 +524,7 @@ void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm8,zmm8,zmm8\t\n"
"vxorps zmm9,zmm9,zmm9\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm11,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -623,7 +628,7 @@ void __attribute__((noinline)) gemmkernel_5x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_6x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -671,6 +676,7 @@ void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm10,zmm10,zmm10\t\n"
"vxorps zmm11,zmm11,zmm11\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm13,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -785,7 +791,7 @@ void __attribute__((noinline)) gemmkernel_6x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_7x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -835,6 +841,7 @@ void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm12,zmm12,zmm12\t\n"
"vxorps zmm13,zmm13,zmm13\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm15,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -960,7 +967,7 @@ void __attribute__((noinline)) gemmkernel_7x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_8x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -1012,6 +1019,7 @@ void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm14,zmm14,zmm14\t\n"
"vxorps zmm15,zmm15,zmm15\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm17,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -1148,7 +1156,7 @@ void __attribute__((noinline)) gemmkernel_8x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_9x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -1202,6 +1210,7 @@ void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm16,zmm16,zmm16\t\n"
"vxorps zmm17,zmm17,zmm17\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm19,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -1349,8 +1358,7 @@ void __attribute__((noinline)) gemmkernel_9x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline))
gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_10x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -1406,6 +1414,7 @@ gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm18,zmm18,zmm18\t\n"
"vxorps zmm19,zmm19,zmm19\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm21,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -1564,8 +1573,7 @@ gemmkernel_10x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline))
gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_11x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -1623,6 +1631,7 @@ gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm20,zmm20,zmm20\t\n"
"vxorps zmm21,zmm21,zmm21\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm23,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -1792,8 +1801,7 @@ gemmkernel_11x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline))
gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_12x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -1853,6 +1861,7 @@ gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm22,zmm22,zmm22\t\n"
"vxorps zmm23,zmm23,zmm23\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm25,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -2033,8 +2042,7 @@ gemmkernel_12x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline))
gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_13x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -2096,6 +2104,7 @@ gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm24,zmm24,zmm24\t\n"
"vxorps zmm25,zmm25,zmm25\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm27,YMMWORD PTR [r10 + 0]\t\n"
Expand Down Expand Up @@ -2287,8 +2296,7 @@ gemmkernel_13x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"r12",
"memory");
}
void __attribute__((noinline))
gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp) {
void __attribute__((noinline)) gemmkernel_14x2_Avx512_fA0fB0fC0(GemmParams* gp) {
asm volatile(
#if !defined(__clang__)
"mov r14, %[gp]\t\n"
Expand Down Expand Up @@ -2352,6 +2360,7 @@ gemmkernel_14x2_AVX512_fA0fB0fC0(GemmParams* gp) {
"vxorps zmm26,zmm26,zmm26\t\n"
"vxorps zmm27,zmm27,zmm27\t\n"


"loop_inner%=:\t\n"

"vcvtph2ps zmm29,YMMWORD PTR [r10 + 0]\t\n"
Expand Down
Loading

0 comments on commit 0bf264f

Please sign in to comment.