Add oneDNN BRGEMM support on CPU (pytorch#131878)
CaoE authored and pytorchmergebot committed Sep 7, 2024
1 parent b53d97c commit f7c0c06
Showing 5 changed files with 440 additions and 5 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -332,6 +332,7 @@ intern_build_aten_ops(
"@fbgemm",
"@mkl",
"@sleef",
"@mkl_dnn//:mkl-dnn",
],
)

375 changes: 374 additions & 1 deletion aten/src/ATen/native/CPUBlas.cpp
@@ -41,6 +41,17 @@ extern "C" void zaxpy_(int *n, void *a, const void *x, int *incx, void *y, int *
#include <fbgemm/FbgemmI64.h>
#endif // USE_FBGEMM

#if AT_MKLDNN_ENABLED()
#include <oneapi/dnnl/dnnl_version.h>
#endif // oneDNN

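// The oneDNN ukernel (brgemm) API is only available in oneDNN >= 3.5.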
#define ONEDNN_UKERNEL_ENABLED (DNNL_VERSION_MAJOR > 3 || (DNNL_VERSION_MAJOR == 3 && DNNL_VERSION_MINOR >= 5))

#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
#include <oneapi/dnnl/dnnl_ukernel.hpp>
#include <oneapi/dnnl/dnnl.hpp>
#endif // oneDNN BRGEMM

namespace at::native::cpublas {
namespace internal {

@@ -822,4 +833,366 @@ void copy(int64_t n, const c10::complex<float> *x, int64_t incx, c10::complex<fl
n, x, incx, y, incy);
}

// oneDNN BRGEMM
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
struct BrgemmKey {
int64_t M;
int64_t N;
int64_t K;
int64_t batch_size;
int64_t lda;
int64_t ldb;
int64_t ldc;
ScalarType dt_a;
ScalarType dt_b;
ScalarType dt_c;
float alpha;
float beta;
BrgemmKey(
int64_t M,
int64_t N,
int64_t K,
int64_t batch_size,
int64_t lda,
int64_t ldb,
int64_t ldc,
ScalarType dt_a,
ScalarType dt_b,
ScalarType dt_c,
float alpha,
float beta)
: M(M),
N(N),
K(K),
batch_size(batch_size),
lda(lda),
ldb(ldb),
ldc(ldc),
dt_a(dt_a),
dt_b(dt_b),
dt_c(dt_c),
alpha(alpha),
beta(beta) {}
bool operator==(const BrgemmKey& other) const {
return M == other.M && N == other.N && K == other.K &&
batch_size == other.batch_size && lda == other.lda &&
ldb == other.ldb && ldc == other.ldc && dt_a == other.dt_a &&
dt_b == other.dt_b && dt_c == other.dt_c && alpha == other.alpha &&
beta == other.beta;
}
};

struct PackKey {
int64_t K;
int64_t N;
int64_t ld_in;
int64_t ld_out;
ScalarType dt_in;
ScalarType dt_out;
PackKey(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out)
: K(K),
N(N),
ld_in(ld_in),
ld_out(ld_out),
dt_in(dt_in),
dt_out(dt_out) {}
bool operator==(const PackKey& other) const {
return N == other.N && K == other.K && ld_in == other.ld_in &&
ld_out == other.ld_out && dt_in == other.dt_in &&
dt_out == other.dt_out;
}
};

inline dnnl::memory::data_type get_dnnl_dtype(ScalarType dtype) {
if (dtype == ScalarType::Float) {
return dnnl::memory::data_type::f32;
} else if (dtype == ScalarType::BFloat16) {
return dnnl::memory::data_type::bf16;
} else if (dtype == ScalarType::Half) {
return dnnl::memory::data_type::f16;
} else if (dtype == ScalarType::Byte) {
return dnnl::memory::data_type::u8;
} else if (dtype == ScalarType::Char) {
return dnnl::memory::data_type::s8;
} else {
TORCH_CHECK(false, "get_dnnl_dtype expects float/bfloat16/half/int8 tensor input");
}
}

template<typename key_t>
struct UnsafeUkernelKeyHasher {
std::size_t operator()(const key_t& key) const;
};

template<>
std::size_t UnsafeUkernelKeyHasher<BrgemmKey>::operator()(const BrgemmKey& key) const {
// Hash on beta, M, N, K, and ldc to keep overhead low: batch size, alpha,
// and the data types are unlikely to change within the same kernel, and the
// leading dimensions are typically tied to M, K, N or fixed.
std::size_t h = std::hash<float>()(key.beta + 1);
h = std::hash<int64_t>()(key.M) ^ (h << 1);
h = std::hash<int64_t>()(key.N) ^ (h << 1);
h = std::hash<int64_t>()(key.K) ^ (h << 1);
h = std::hash<int64_t>()(key.ldc) ^ (h << 1);
return h;
}

template<>
std::size_t UnsafeUkernelKeyHasher<PackKey>::operator()(const PackKey& key) const {
// Hash on K and N to keep overhead low: the data types are unlikely to
// change, and ld_in/ld_out are typically tied to K, N or fixed.
std::size_t h = std::hash<int64_t>()(key.K);
h = std::hash<int64_t>()(key.N) ^ (h << 1);
return h;
}

template <typename key_t, typename value_t>
struct KernelCache {
using kstore_t = std::unordered_map<key_t, std::shared_ptr<value_t>, UnsafeUkernelKeyHasher<key_t>>;
static inline std::shared_ptr<value_t>& fetch_or_create(
const key_t& key,
const std::function<std::shared_ptr<value_t>()>& callback) {
auto&& search = get_store().find(key);
if (search != get_store().end()) {
return search->second;
} else {
get_store().insert({key, callback()});
return get_store()[key];
}
}

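// Thread-local store: each thread caches its own generated kernels,
// so lookups need no synchronization.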
static inline kstore_t& get_store() {
static thread_local kstore_t cache_kernels;
return cache_kernels;
}
};

// Helper struct for convenient brgemm configuration
struct GemmHelper {
GemmHelper(
int64_t M,
int64_t N,
int64_t K,
int64_t bs,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
ScalarType dt_a,
ScalarType dt_b,
ScalarType dt_c,
const float alpha,
const float beta) {
// Create brgemm
brg = dnnl::ukernel::brgemm(
M,
N,
K,
bs,
ld_a,
ld_b,
ld_c,
get_dnnl_dtype(dt_a),
get_dnnl_dtype(dt_b),
get_dnnl_dtype(dt_c),
alpha,
beta);
// Create a scratchpad buffer for the brgemm execution
scratchpad = std::vector<uint8_t>(brg.get_scratchpad_size());
// Prepare the default (single-batch) pair of A and B offsets.
A_B_offsets.resize(1);
A_B_offsets[0] = std::make_pair(0, 0);
}
dnnl::ukernel::brgemm brg;
std::vector<uint8_t> scratchpad;
std::vector<std::pair<int64_t, int64_t>> A_B_offsets;
};

struct Brgemm : public KernelCache <BrgemmKey, GemmHelper> {
// Fetch/create GemmHelper object and execute brgemm with batch size = 1
template <typename scalar_t_a, typename scalar_t_b, typename scalar_t_c>
static inline void call(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const scalar_t_a* A,
const scalar_t_b* B,
scalar_t_c* C) {
auto&& key = BrgemmKey(
M,
N,
K,
int64_t(1),
ld_a,
ld_b,
ld_c,
c10::CppTypeToScalarType<scalar_t_a>::value,
c10::CppTypeToScalarType<scalar_t_b>::value,
c10::CppTypeToScalarType<scalar_t_c>::value,
alpha,
beta);
// Fetch/create GemmHelper object
auto&& value = fetch_or_create(key, [&]() {
auto&& v = std::make_shared<GemmHelper>(
M,
N,
K,
1,
ld_a,
ld_b,
ld_c,
c10::CppTypeToScalarType<scalar_t_a>::value,
c10::CppTypeToScalarType<scalar_t_b>::value,
c10::CppTypeToScalarType<scalar_t_c>::value,
alpha,
beta);
(*v).brg.generate();
return std::move(v);
});
if (get_current() != value) {
dnnl::ukernel::brgemm::release_hw_context();
((*value).brg).set_hw_context();
get_current() = value;
}
((*value).brg)
.execute(A, B, (*value).A_B_offsets, C, (*value).scratchpad.data());
}

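// The GemmHelper whose kernel currently holds the hardware (e.g. AMX)
// context on this thread; tracked to avoid redundant context resets.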
static inline std::shared_ptr<GemmHelper>& get_current() {
static thread_local std::shared_ptr<GemmHelper> current;
return current;
}

static inline bool device_check(ScalarType dtype) {
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (dtype == ScalarType::Half) {
static bool fp16_support = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_fp16;
return fp16_support;
}
return false;
}
};

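// Packs the B matrix into the blocked layout expected by the brgemm ukernel.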
using pack_t = dnnl::ukernel::brgemm_pack_B;
struct Pack : public KernelCache <PackKey, pack_t> {
static inline void call(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out) {
auto&& key = PackKey(K, N, ld_in, ld_out, dt_in, dt_out);
auto&& pack = fetch_or_create(key, [&]() {
auto&& p = std::make_shared<pack_t>(
K, N, ld_in, ld_out, get_dnnl_dtype(dt_in), get_dnnl_dtype(dt_out));
if (need_pack(dt_in)) {
(*p).generate();
}
return std::move(p);
});
if (need_pack(dt_in)) {
(*pack).execute(in, out);
} else {
TORCH_CHECK(false, "Pack::call was invoked, but packing is not needed for this dtype on this ISA");
}
}

static inline bool need_pack(ScalarType dtype) {
if (!at::globalContext().userEnabledMkldnn()) {
return false;
}
if (dtype == ScalarType::Half) {
static bool fp16_pack = dnnl::get_effective_cpu_isa() >= dnnl::cpu_isa::avx512_core_amx_fp16;
return fp16_pack;
}
return false;
}
};
#endif

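// Public wrappers: dispatch to the oneDNN ukernel when the build and the CPU
// support it; otherwise raise an error.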
void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const at::Half* A,
const at::Half* B,
float* C) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
if (Brgemm::device_check(ScalarType::Half)) {
Brgemm::call<at::Half, at::Half, float>(
M, N, K, ld_a, ld_b, ld_c, alpha, beta, A, B, C);
return;
}
#endif
TORCH_CHECK(false,
"Half Brgemm is only supported on X64 when oneDNN ukernel is enabled and avx512_core_fp16 is supported");
}

void brgemm(
int64_t M,
int64_t N,
int64_t K,
int64_t ld_a,
int64_t ld_b,
int64_t ld_c,
const float alpha,
const float beta,
const at::BFloat16* A,
const at::BFloat16* B,
float* C) {
TORCH_CHECK(false,
"BFloat16 Brgemm is currently not supported");
}

void brgemm_release() {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
dnnl::ukernel::brgemm::release_hw_context();
#endif
}

void pack(
int64_t K,
int64_t N,
int64_t ld_in,
int64_t ld_out,
ScalarType dt_in,
ScalarType dt_out,
const void* in,
void* out) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
Pack::call(K, N, ld_in, ld_out, dt_in, dt_out, in, out);
#else
TORCH_CHECK(false, "pack is only supported on X64 with oneDNN ukernel enabled");
#endif
}

bool need_pack(ScalarType dt_in) {
#if ONEDNN_UKERNEL_ENABLED && (defined(__x86_64__) || (defined(_M_X64) && !defined(_M_ARM64EC)))
return Pack::need_pack(dt_in);
#else
return false;
#endif
}

} // namespace at::native::cpublas
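
For context, here is a minimal sketch of how a caller might drive the new entry points above (presumably declared in ATen/native/CPUBlas.h alongside this .cpp). The sizes, buffer names, and the helper run_fp16_brgemm_once are illustrative assumptions, not part of this commit; note that on machines without avx512_core_fp16 the Half overload raises an error rather than falling back.

#include <vector>
#include <ATen/native/CPUBlas.h>

// Hypothetical driver for the new ukernel-backed API (sizes are illustrative).
void run_fp16_brgemm_once() {
  const int64_t M = 16, N = 64, K = 32;
  std::vector<at::Half> A(M * K), B(K * N);   // fp16 inputs (filled elsewhere)
  std::vector<float> C(M * N, 0.f);           // fp32 output/accumulator

  // Pack B only when the ISA requires it (e.g. AMX); otherwise use B as-is.
  std::vector<at::Half> B_packed(K * N);
  const at::Half* B_ptr = B.data();
  if (at::native::cpublas::need_pack(at::ScalarType::Half)) {
    at::native::cpublas::pack(
        K, N, /*ld_in=*/N, /*ld_out=*/N,
        at::ScalarType::Half, at::ScalarType::Half,
        B.data(), B_packed.data());
    B_ptr = B_packed.data();
  }

  // C = 1.0 * A @ B + 0.0 * C via the oneDNN brgemm ukernel.
  at::native::cpublas::brgemm(
      M, N, K, /*ld_a=*/K, /*ld_b=*/N, /*ld_c=*/N,
      /*alpha=*/1.f, /*beta=*/0.f, A.data(), B_ptr, C.data());

  // Release the hardware (AMX) context once the brgemm sequence is done.
  at::native::cpublas::brgemm_release();
}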