Skip to content

Commit

Permalink
Add the Naive bfloat16 implementation based on MKL
Browse files Browse the repository at this point in the history
Summary:
Add the Naive bfloat16 implementation based on MKL.

For this Naive bfloat16 implementation for C += A * B (A, B, and C are all bfloat16 type), we do the following three steps:
1. Convert bfloat16 A, B, C to fp32;
2. Call cblas_sgemm from MKL/BLAS;
3. Convert fp32 C back to bfloat16 C.

Reviewed By: jspark1105

Differential Revision: D14391444

fbshipit-source-id: 1147dd2a18c4bbdec6c15f1d0f15d698d3741afe
  • Loading branch information
jianyuh authored and facebook-github-bot committed Mar 18, 2019
1 parent 6011ce3 commit 1351790
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
39 changes: 39 additions & 0 deletions src/RefImplementations.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
*/
#include "RefImplementations.h"

#include "fbgemm/Types.h"

#include <algorithm>
#include <cassert>
#include <cmath>
Expand Down Expand Up @@ -166,6 +168,43 @@ void matmul_fp_ref(
}
}

// Reference (naive triple-loop) implementation of cblas_sgemm:
//   C = alpha * op(A) * op(B) + beta * C
// All matrices are row-major; op(X) is X or X^T depending on transa/transb.
void cblas_sgemm_ref(
    const matrix_op_t transa,
    const matrix_op_t transb,
    const int m,
    const int n,
    const int k,
    float alpha,
    const float* Afp32,
    int lda,
    const float* Bfp32,
    int ldb,
    float beta,
    float* Cfp32,
    int ldc
) {
  const bool aNoTrans = (transa == matrix_op_t::NoTranspose);
  const bool bNoTrans = (transb == matrix_op_t::NoTranspose);
  for (int row = 0; row < m; ++row) {
    for (int col = 0; col < n; ++col) {
      // Dot product of op(A)'s row `row` with op(B)'s column `col`.
      float acc = 0;
      for (int inner = 0; inner < k; ++inner) {
        const float aVal =
            aNoTrans ? Afp32[row * lda + inner] : Afp32[inner * lda + row];
        const float bVal =
            bNoTrans ? Bfp32[inner * ldb + col] : Bfp32[col * ldb + inner];
        acc += aVal * bVal;
      }
      float& cElem = Cfp32[row * ldc + col];
      if (beta == 0) {
        // BLAS convention: when beta == 0, C is write-only and must not be
        // read (it may be uninitialized).
        cElem = alpha * acc;
      } else {
        cElem = alpha * acc + beta * cElem;
      }
    }
  }
}


void row_offsets_u8acc32_ref(
int M,
int K,
Expand Down
19 changes: 19 additions & 0 deletions src/RefImplementations.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,25 @@ void FBGEMM_API matmul_fp_ref(
const float* Bfp32,
float* Cfp32);

/**
 * @brief Reference implementation of cblas_sgemm from MKL/BLAS:
 *        C = alpha * op(A) * op(B) + beta * C, with all matrices row-major.
 *
 * @param transa whether op(A) is A (NoTranspose) or A^T
 * @param transb whether op(B) is B (NoTranspose) or B^T
 * @param m number of rows of op(A) and C
 * @param n number of columns of op(B) and C
 * @param k inner dimension: columns of op(A) / rows of op(B)
 * @param alpha scalar multiplier applied to op(A) * op(B)
 * @param Afp32 pointer to A in fp32, leading dimension lda
 * @param Bfp32 pointer to B in fp32, leading dimension ldb
 * @param beta scalar multiplier applied to the existing C; when beta == 0,
 *             C is treated as write-only (BLAS convention)
 * @param Cfp32 pointer to C in fp32 (output), leading dimension ldc
 */
void FBGEMM_API cblas_sgemm_ref(
    const matrix_op_t transa,
    const matrix_op_t transb,
    const int m,
    const int n,
    const int k,
    float alpha,
    const float* Afp32,
    int lda,
    const float* Bfp32,
    int ldb,
    float beta,
    float* Cfp32,
    int ldc
    );

/**
* @brief Reference implementation to compute row_offsets (sums of rows of A).
*/
Expand Down

0 comments on commit 1351790

Please sign in to comment.