Skip to content

Commit

Permalink
Extend kernel code to support multiple intruction sets (pytorch#391)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#391

Restructure code to be prepare adding new uArch support.
- Convert JIT generators to template
- Replace if() with switch() statements
- Generalize common routines

Reviewed By: dskhudia

Differential Revision: D22385872

fbshipit-source-id: 09c669beaebc341573c5b0aca8ee17eadd363995
  • Loading branch information
efiks authored and facebook-github-bot committed Jul 29, 2020
1 parent 139c6f2 commit cad1c21
Show file tree
Hide file tree
Showing 19 changed files with 628 additions and 670 deletions.
1 change: 1 addition & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ def get_fbgemm_generic_srcs(with_base = False):
"src/FbgemmFloat16Convert.cc",
"src/FbgemmI64.cc",
"src/FbgemmI8Spmdm.cc",
"src/GenerateKernel.cc",
"src/GenerateKernelU8S8S32ACC16.cc",
"src/GenerateKernelU8S8S32ACC16Avx512.cc", # Acc16 AVX512 JIT code gen
"src/GenerateKernelU8S8S32ACC16Avx512VNNI.cc",
Expand Down
103 changes: 103 additions & 0 deletions include/fbgemm/PackingTraits-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,19 @@ struct PackingTraits<
static constexpr int NCB{
8}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{512}; ///< Cache block for K dimension.

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getKernelParams() {
return std::tuple<int, int, int>(int(MCB), int(NCB), int(NR_MIN));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
Expand Down Expand Up @@ -112,6 +125,19 @@ struct PackingTraits<
static constexpr int NCB{
64}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getKernelParams() {
return std::tuple<int, int, int>(int(MCB), int(NCB), int(NR_MIN));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
Expand All @@ -133,6 +159,16 @@ struct PackingTraits<float, float, inst_set_t::avx2> {
static constexpr int NCB{
64}; ///< Cache block for N dimension (multiple of NR)
static constexpr int KCB{256}; ///< Cache block for K dimension

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
Expand Down Expand Up @@ -183,6 +219,19 @@ struct PackingTraits<
static constexpr int NCB{
32}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getKernelParams() {
return std::tuple<int, int, int>(int(MCB), int(NCB), int(NR_MIN));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
Expand Down Expand Up @@ -220,6 +269,19 @@ struct PackingTraits<
static constexpr int NCB{
128}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{256}; ///< Cache block for K dimension.

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getKernelParams() {
return std::tuple<int, int, int>(int(MCB), int(NCB), int(NR_MIN));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
Expand Down Expand Up @@ -270,4 +332,45 @@ struct PackingTraits<
static constexpr int NCB{
48}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{512}; ///< Cache block for K dimension.

static std::tuple<int, int, int> getCacheBlockParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(MR));
}
static std::tuple<int, int, int> getKernelParams() {
return std::tuple<int, int, int>(int(MCB), int(NCB), int(NR_MIN));
}
static std::tuple<int, int, int> getMatrixPackAParams() {
return std::tuple<int, int, int>(int(MCB), int(KCB), int(ROW_INTERLEAVE));
}
static std::tuple<int, int, int> getMatrixPackBParams() {
return std::tuple<int, int, int>(int(KCB), int(NCB), int(ROW_INTERLEAVE));
}
};

/**
* @brief Packing parameter specialization for I64 GEMM
* integers.
*
* This is picked when T is of int64 type and instruction
* set is avx512.
*/
template <>
struct PackingTraits<int64_t, int64_t, inst_set_t::avx512> {
static constexpr int MR{2}; ///< Register block for M dimension.
static constexpr int NR_MIN{8}; ///< Minimum register block for N dimension.
///< 8 because 8 int64 elements
///< completely fill a 512-bit wide vector.
static constexpr int NR{
32}; ///< Register block for N dimension.
///< Must be a multiple of 16 because 16*ROW_INTERLEAVE int8 elements
///< completely fill a 512-bit wide vector. Total registers used for
///< N dimension: NR*8/VLEN. We use MR x
///< NR*8/VLEN zmm registers
///< for C accumulations.

static constexpr int MCB{
16}; ///< Cache block for M dimension (multiple of MR).
static constexpr int NCB{
64}; ///< Cache block for N dimension (multiple of NR).
static constexpr int KCB{8}; ///< Cache block for K dimension.
};
Loading

0 comments on commit cad1c21

Please sign in to comment.