Move FP16 function to common location (pytorch#438)
Summary:
Pull Request resolved: pytorch#438

Extract the common C++ code and templatize it

Reviewed By: dskhudia

Differential Revision: D22152352

fbshipit-source-id: 62e8b85c437a2bf957833d821c7916ad432fffa7
efiks authored and facebook-github-bot committed Oct 15, 2020
1 parent 48bc1b2 commit abc56f6
Showing 10 changed files with 595 additions and 503 deletions.
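
For orientation before the per-file diffs: the extraction follows the standard C++ explicit-instantiation pattern. The templated driver is defined and compiled once in a shared source file (the new src/FbgemmFPCommon.cc), the public header keeps a templated declaration plus an extern template declaration so that including translation units do not instantiate it themselves, and the float16 entry point behaves exactly as before. A minimal sketch of that pattern, using illustrative names (compute.h, compute.cc, PackedMatrix, gemm_compute) rather than the real FBGEMM identifiers:

// compute.h -- illustrative header, standing in for FbgemmFP16.h
#pragma once
#include <cstdint>

using float16 = std::uint16_t;   // fp16 payload carried in a 16-bit integer

template <typename T>
struct PackedMatrix {};          // stand-in for PackedGemmMatrixB<T>

// Templated declaration, usable for any packed element type T.
template <typename T>
void gemm_compute(const PackedMatrix<T>& Bp, const float* A, float* C);

// Promise that the float16 instantiation exists in some object file, so
// includers do not instantiate (and cannot inline) it themselves.
extern template void gemm_compute<float16>(
    const PackedMatrix<float16>& Bp, const float* A, float* C);

// compute.cc -- illustrative source, analogous to src/FbgemmFPCommon.cc
#include "compute.h"

template <typename T>
void gemm_compute(const PackedMatrix<T>& /*Bp*/, const float* /*A*/, float* /*C*/) {
  // type-agnostic driver logic shared across element types goes here
}

// Explicit instantiation: emits the float16 symbol in this translation unit.
template void gemm_compute<float16>(
    const PackedMatrix<float16>&, const float*, float*);
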
1 change: 1 addition & 0 deletions defs.bzl
@@ -19,6 +19,7 @@ def get_fbgemm_generic_srcs(with_base = False):
         "src/Fbgemm.cc",
         "src/FbgemmBfloat16Convert.cc",
         "src/FbgemmConv.cc",
+        "src/FbgemmFPCommon.cc",
         "src/FbgemmFP16.cc",
         "src/FbgemmFloat16Convert.cc",
         "src/FbgemmI64.cc",
17 changes: 13 additions & 4 deletions include/fbgemm/FbgemmFP16.h
@@ -23,17 +23,26 @@
 namespace fbgemm {
 
 using PackedGemmMatrixFP16 = PackedGemmMatrixB<float16>;
-/**
- * restrictions: transa == CblasNoTrans
- */
+
+template <typename T>
 FBGEMM_API void cblas_gemm_compute(
     const matrix_op_t transa,
     const int m,
     const float* A,
-    const PackedGemmMatrixFP16& Bp,
+    const PackedGemmMatrixB<T>& Bp,
     const float beta,
     float* C,
     int thread_id = 0,
     int num_threads = 1);
 
+extern template void cblas_gemm_compute<float16>(
+    const matrix_op_t transa,
+    const int m,
+    const float* A,
+    const PackedGemmMatrixFP16& Bp,
+    const float beta,
+    float* C,
+    int thread_id,
+    int num_threads);
+
 }; // namespace fbgemm
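
The call site is unchanged by the templatization: passing a PackedGemmMatrixFP16 deduces T = float16, and the extern template declaration above ensures the call resolves to the instantiation compiled inside the library rather than being re-instantiated by the caller. A hedged usage sketch follows; the PackedGemmMatrixFP16 constructor arguments (trans, nrow, ncol, alpha, source pointer) are an assumption about the packing API and are not shown in this diff.

#include "fbgemm/FbgemmFP16.h"
#include <vector>

// A (m x k, row major) times packed B (k x n) into C (m x n); float32
// activations with an fp16-packed B, as computed by cblas_gemm_compute.
void fp16_gemm_example(int m, int k, int n,
                       const std::vector<float>& A,
                       const std::vector<float>& B,
                       std::vector<float>& C) {
  // Pack B once and reuse it across calls. Constructor signature assumed.
  fbgemm::PackedGemmMatrixFP16 Bp(
      fbgemm::matrix_op_t::NoTranspose, k, n, /*alpha=*/1.0f, B.data());

  // Same entry point as before this change; T = float16 is deduced from Bp
  // and linked against the explicit instantiation inside the library.
  fbgemm::cblas_gemm_compute(
      fbgemm::matrix_op_t::NoTranspose, m, A.data(), Bp,
      /*beta=*/0.0f, C.data());
}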