forked from pytorch/FBGEMM
-
Notifications
You must be signed in to change notification settings - Fork 0
/
FbgemmFP16UKernelsIntrinsicAvx512.cc
141 lines (134 loc) · 4.76 KB
/
FbgemmFP16UKernelsIntrinsicAvx512.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#if defined(__x86_64__) || defined(__i386__) || \
    (defined(_MSC_VER) && (defined(_M_X64) || defined(_M_IX86)))
#include <immintrin.h>
#endif
#include <cassert>
#include <cstddef>
#include <cstdint>

#include "./FbgemmFP16UKernelsAvx512.h"
namespace fbgemm {

// Intrinsic AVX-512 GEMM micro-kernel for MSVC: fp32 A, fp16 B, fp32 C.
//
// For every column block ii in [0, gp->b_block_cols) it computes one
// kernel_nrows x 32 tile of C:
//
//   C[r][c] = beta * C[r][c] + sum_{kk < k} A[kk][r] * B_ii[kk][c]
//
// Memory layout, as implied by the indexing below:
//   - gp->A: k groups of kernel_nrows floats (one value per tile row per k
//     step); the same A data is re-read for every column block.
//   - gp->B: b_block_cols consecutive blocks of k x 32 fp16 values. Each
//     16-value half is fetched with an aligned 256-bit load, so every block
//     start must be 32-byte aligned.
//   - gp->C: the row stride is gp->ldc BYTES; column block ii starts 32
//     floats after block ii - 1.
//
// When gp->beta == 0, C is only written (never loaded), so it may hold
// uninitialized data. When gp->k == 0 the tile becomes beta * C (or zeros).
//
// @param gp           kernel parameters; the struct itself is not modified.
// @param kernel_nrows number of tile rows; at most 14 (two 16-float
//                     accumulators per row, 28 in total).
void gemmkernel_Avx512_fp16_fA0fB0fC0(
    GemmParamsFP16* gp,
    const size_t kernel_nrows) {
  constexpr size_t kMaxKernelRows = 14;
  constexpr size_t kVecFloats = 16; // floats (or fp16 values) per vector
  constexpr size_t kTileCols = 2 * kVecFloats; // C columns per block

  // More rows than this would write past the end of zmmSum.
  assert(kernel_nrows <= kMaxKernelRows);

  // zmmSum[2 * r] holds columns [0, 16) of tile row r,
  // zmmSum[2 * r + 1] holds columns [16, 32).
  __m512 zmmSum[2 * kMaxKernelRows];

  // ldc is given in bytes; convert to a float-element stride.
  const size_t ldc_floatsize = gp->ldc / sizeof(float);
  const bool accumulate = gp->beta != 0;
  const __m512 zmmBeta = _mm512_set1_ps(gp->beta);

  size_t idxB = 0; // B is consumed sequentially across all column blocks
  size_t idxC = 0;
  for (uint64_t ii = 0; ii < gp->b_block_cols; ii++) {
    // Seed the accumulators before the k loop rather than on its first
    // iteration: the inner loop stays branch-free and k == 0 no longer
    // stores uninitialized registers.
    if (accumulate) {
      for (size_t jj = 0; jj < kernel_nrows; jj++) {
        const size_t row_off = idxC + jj * ldc_floatsize;
        zmmSum[2 * jj] =
            _mm512_mul_ps(zmmBeta, _mm512_loadu_ps(gp->C + row_off));
        zmmSum[2 * jj + 1] = _mm512_mul_ps(
            zmmBeta, _mm512_loadu_ps(gp->C + row_off + kVecFloats));
      }
    } else {
      // beta == 0: C must not be read. Starting from +0 and using FMA only
      // differs from a plain multiply in the sign of an exactly-zero sum.
      for (size_t jj = 0; jj < 2 * kernel_nrows; jj++) {
        zmmSum[jj] = _mm512_setzero_ps();
      }
    }

    // One rank-1 update per k step: widen 32 fp16 B values to fp32,
    // broadcast one A value per tile row, and fused-multiply-add.
    size_t idxA = 0;
    for (uint64_t kk = 0; kk < gp->k; kk++) {
      const __m512 zmmB0 = _mm512_cvtph_ps(_mm256_load_si256(
          reinterpret_cast<const __m256i*>(gp->B + idxB)));
      const __m512 zmmB1 = _mm512_cvtph_ps(_mm256_load_si256(
          reinterpret_cast<const __m256i*>(gp->B + idxB + kVecFloats)));
      idxB += kTileCols;

      for (size_t jj = 0; jj < kernel_nrows; jj++) {
        const __m512 zmmA = _mm512_set1_ps(gp->A[idxA + jj]);
        zmmSum[2 * jj] = _mm512_fmadd_ps(zmmA, zmmB0, zmmSum[2 * jj]);
        zmmSum[2 * jj + 1] =
            _mm512_fmadd_ps(zmmA, zmmB1, zmmSum[2 * jj + 1]);
      }
      idxA += kernel_nrows;
    }

    // Write the finished tile back (unaligned stores: C has no alignment
    // requirement here).
    for (size_t jj = 0; jj < kernel_nrows; jj++) {
      const size_t row_off = idxC + jj * ldc_floatsize;
      _mm512_storeu_ps(gp->C + row_off, zmmSum[2 * jj]);
      _mm512_storeu_ps(gp->C + row_off + kVecFloats, zmmSum[2 * jj + 1]);
    }
    idxC += kTileCols;
  }
}

// Fixed-height entry points: gemmkernel_<R>x2_... runs the generic kernel
// above with an R-row tile ("x2" presumably refers to the two 16-float
// vectors per row, i.e. 32 columns — TODO confirm against the header).
void NOINLINE gemmkernel_1x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 1);
}
void NOINLINE gemmkernel_2x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 2);
}
void NOINLINE gemmkernel_3x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 3);
}
void NOINLINE gemmkernel_4x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 4);
}
void NOINLINE gemmkernel_5x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 5);
}
void NOINLINE gemmkernel_6x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 6);
}
void NOINLINE gemmkernel_7x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 7);
}
void NOINLINE gemmkernel_8x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 8);
}
void NOINLINE gemmkernel_9x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 9);
}
void NOINLINE gemmkernel_10x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 10);
}
void NOINLINE gemmkernel_11x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 11);
}
void NOINLINE gemmkernel_12x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 12);
}
void NOINLINE gemmkernel_13x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 13);
}
void NOINLINE gemmkernel_14x2_Avx512_fp16_fA0fB0fC0(GemmParamsFP16* gp) {
  gemmkernel_Avx512_fp16_fA0fB0fC0(gp, 14);
}
} // namespace fbgemm