Skip to content

Commit

Permalink
Refactoring GenerateKernel files (pytorch#178)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#178

Fix bug in computing colRegs (doesn't affect correctness for int8_t)
Remove unused parameter ldc in getOrCreate function
Remove unused function signature nr_min
Remove leadingDimCRegAssign argument with default value in initCRegs, genComputeBlock, storeCRegs which can be error prone

Reviewed By: jianyuh

Differential Revision: D18442555

fbshipit-source-id: b6e2d4a98f383f442a5408e70cf6ad04ac4b4938
  • Loading branch information
jspark1105 committed Mar 21, 2020
1 parent 6d6c798 commit 0e8b68c
Show file tree
Hide file tree
Showing 8 changed files with 143 additions and 224 deletions.
22 changes: 9 additions & 13 deletions src/ExecuteKernelU8S8.cc
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ ExecuteKernel<
nrMinSize_ = PackingTraits<
int8_t,
typename packingAMatrix::accType,
inst_set_t::avx2>::NR;
inst_set_t::avx2>::NR_MIN;
} else {
assert(0 && "unsupported architecure");
}
Expand Down Expand Up @@ -140,30 +140,26 @@ void ExecuteKernel<
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
packedA_.numPackedCols());
} else {
fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
packedA_.numPackedCols());
}
} else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
packedA_.numPackedCols());
} else if (fbgemmHasAvx2Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx2>(
accum,
packed_rows_A,
packedB_.blockColSize(),
packedA_.numPackedCols(),
nbSize_);
packedA_.numPackedCols());
} else {
// TODO: Have default slower path
assert(0 && "unsupported architecture");
Expand All @@ -186,17 +182,17 @@ void ExecuteKernel<
// For AVX512VNNI, we redirect int16_t to int32_t accumulation.
CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
fn = codeObj.getOrCreate<inst_set_t::avx512_vnni>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
accum, packed_rows_A, nc, packedA_.numPackedCols());
} else {
fn = BaseType::template getOrCreate<inst_set_t::avx512_vnni>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
accum, packed_rows_A, nc, packedA_.numPackedCols());
}
} else if (fbgemmHasAvx512Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx512>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
accum, packed_rows_A, nc, packedA_.numPackedCols());
} else if (fbgemmHasAvx2Support()) {
fn = BaseType::template getOrCreate<inst_set_t::avx2>(
accum, packed_rows_A, nc, packedA_.numPackedCols(), nbSize_);
accum, packed_rows_A, nc, packedA_.numPackedCols());
} else {
// TODO: Have default slower path
assert(0 && "unsupported architecture");
Expand Down
51 changes: 21 additions & 30 deletions src/GenerateKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ template <typename TA, typename TB, typename TC, typename accT>
class CodeGenBase {
public:
using jit_micro_kernel_fp = void (*)(
TA* bufferA,
TB* bufferB,
TB* b_pf,
const TA* bufferA,
const TB* bufferB,
const TB* b_pf,
TC* bufferC,
int kc,
int ldc);
Expand Down Expand Up @@ -72,17 +72,13 @@ class CodeGenBase {
*/
template <inst_set_t instSet>
jit_micro_kernel_fp
getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc, int32_t ldc);
getOrCreate(bool accum, int32_t mc, int32_t nc, int32_t kc);

/**
* @brief Generate instructions for initializing the C registers to 0.
*/
template <inst_set_t instSet>
void initCRegs(
x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCRegAssign = 4);
void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);

/**
* @brief Generate instructions for computing block in the rank-k update.
Expand All @@ -95,8 +91,7 @@ class CodeGenBase {
x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda,
int leadingDimCRegAssign = 4);
int lda);

/**
* @brief Generate instructions for storing the C registers back to the
Expand All @@ -109,8 +104,7 @@ class CodeGenBase {
int colRegs,
x86::Gp C_Offset,
x86::Gp ldcReg,
bool accum,
int leadingDimCRegAssign = 4);
bool accum);

const BlockingFactors* blocking_params;
/**
Expand All @@ -125,8 +119,7 @@ class CodeGenBase {
int NCB,
int KCB,
int MR,
int NR,
int NR_MIN) {
int NR) {
std::ostringstream oss;
oss << "gemm_";
if (std::is_same<accT, std::int16_t>::value) {
Expand All @@ -136,14 +129,10 @@ class CodeGenBase {
} else {
oss << "unknown_";
}
oss << "accum-" + std::to_string(accum)
<< "_MC-" + std::to_string(mc)
<< "_NC-" + std::to_string(nc)
<< "_NCB-" + std::to_string(NCB)
<< "_NCB-" + std::to_string(KCB)
<< "_MR-" + std::to_string(MR)
<< "_NR-" + std::to_string(NR)
<< "_NR_MIN-" + std::to_string(NR_MIN);
oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc)
<< "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB)
<< "_NCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR)
<< "_NR-" + std::to_string(NR);
if (instSet == inst_set_t::avx512_vnni) {
oss << "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
Expand All @@ -159,28 +148,30 @@ class CodeGenBase {
int vectorWidth_; ///< Vector width in bits.
int VLEN_; ///< Vector width in elements.

static asmjit::JitRuntime &runtime() {
static asmjit::JitRuntime& runtime() {
static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
// depents on other static
// variables. Required to prevent
// initialization order fiasco
return rt;
}

static std::mutex rtMutex_; ///< Controll access to runtime;
static std::mutex rtMutex_; ///< Controll access to runtime;

// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr, nr_min
static CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
jit_micro_kernel_fp>
// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
static CodeCache<
std::tuple<bool, int, int, int, int, int, int>,
jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.
};

template <typename TA, typename TB, typename TC, typename accT>
std::mutex CodeGenBase<TA, TB, TC, accT>::rtMutex_;

template <typename TA, typename TB, typename TC, typename accT>
CodeCache<std::tuple<bool, int, int, int, int, int, int, int>,
typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
CodeCache<
std::tuple<bool, int, int, int, int, int, int>,
typename CodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
CodeGenBase<TA, TB, TC, accT>::codeCache_;

} // namespace fbgemm
54 changes: 24 additions & 30 deletions src/GenerateKernelU8S8S32ACC16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,18 +18,14 @@ namespace x86 = asmjit::x86;
template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(
x86::Emitter* a,
int rowRegs,
int colRegs,
int leadingDimCReg) {
inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) {
using CRegs = x86::Ymm;
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
CRegs(i * leadingDimCReg + j),
CRegs(i * leadingDimCReg + j),
CRegs(i * leadingDimCReg + j));
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
}
}
}
Expand All @@ -48,10 +44,9 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
x86::Gp /* unused (reserved for prefetching)*/,
int rowRegs,
int colRegs,
int lda,
int leadingDimCReg) {
int lda) {
// used for matrix A
x86::Ymm AReg = x86::ymm12;
x86::Ymm AReg = x86::ymm13;

x86::Ymm tmpReg = x86::ymm14;

Expand All @@ -64,8 +59,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::genComputeBlock<
for (int j = 0; j < colRegs; ++j) {
a->vpmaddubsw(
tmpReg, AReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vpaddsw(
CRegs(i * leadingDimCReg + j), tmpReg, CRegs(i * leadingDimCReg + j));
a->vpaddsw(CRegs(i * colRegs + j), tmpReg, CRegs(i * colRegs + j));
// Prefetching is hurting performance in some cases
// because prefetch instructions itself consumes a slot
// in pipeline issue thus slowing down the kernel.
Expand All @@ -89,8 +83,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
int colRegs,
x86::Gp C_Offset,
x86::Gp ldcReg,
bool accum,
int leadingDimCReg) {
bool accum) {
x86::Xmm extractDest128 = x86::xmm15;
x86::Ymm extractDest256 = x86::ymm15;

Expand All @@ -99,7 +92,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::storeCRegs<
a->imul(C_Offset, ldcReg, static_cast<asmjit::Imm>(i * sizeof(int32_t)));
for (int j = 0; j < colRegs; ++j) {
for (int idx = 0; idx < 2; ++idx) {
a->vextracti128(extractDest128, CRegs(i * leadingDimCReg + j), idx);
a->vextracti128(extractDest128, CRegs(i * colRegs + j), idx);
a->vpmovsxwd(extractDest256, extractDest128);
x86::Mem destAddr = x86::dword_ptr(
a->zcx(), C_Offset, 0, (j * 2 + idx) * 8 * sizeof(int32_t));
Expand All @@ -123,9 +116,8 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
bool accum,
int32_t mc,
int32_t nc,
int32_t kc,
int32_t /* unused */) {
std::tuple<bool, int, int, int, int, int, int, int> kernelSig;
int32_t kc) {
std::tuple<bool, int, int, int, int, int, int> kernelSig;
int kBlock;
int nBlock;
int mRegBlockSize;
Expand All @@ -152,14 +144,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
}

kernelSig = std::make_tuple(
accum,
mc,
nc,
nBlock,
kBlock,
mRegBlockSize,
nRegBlockSize,
nRegBlockSizeMin);
accum, mc, nc, nBlock, kBlock, mRegBlockSize, nRegBlockSize);

return codeCache_.getOrCreate(kernelSig, [&]() -> jit_micro_kernel_fp {
asmjit::CodeHolder code;
Expand Down Expand Up @@ -187,10 +172,19 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
}
#endif

int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;
assert(
kc % row_interleave == 0 && "kc must be a multiple of row_interleave");
assert(nc % nRegBlockSizeMin == 0 && "nc must be a multiple of NR_MIN");
int maxMRegs = mRegBlockSize;
int maxNRegs = nRegBlockSize * row_interleave / VLEN_;
assert(
maxMRegs * maxNRegs <= 13 &&
"MR*(NR*ROW_INTERLEAVE*8/256"
"must be <= 13(available registers constraint)");

int mRegBlocks = mc / mRegBlockSize;
int mRegBlocksRem = mc % mRegBlockSize;

// assert((nc == nRegBlockSize) &&
//"nc must be equal to the number of register blocks");

Expand Down Expand Up @@ -239,7 +233,7 @@ CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::getOrCreate<inst_set_t::avx2>(
x86::Gp iIdx = a->gpz(13);
x86::Gp kIdx = a->gpz(14);

int colRegs = nc * row_interleave * sizeof(int8_t) / VLEN_;
int colRegs = nc * row_interleave / VLEN_;
if (mRegBlocks > 0) {
// move 0 to iteration variables
a->mov(iIdx, 0);
Expand Down
Loading

0 comments on commit 0e8b68c

Please sign in to comment.