Skip to content

Commit

Permalink
Minor improvements in GEMM Kernels (pytorch#368)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: pytorch#368

1) Replace vxorps with vpxor. vpxor can execute on multiple ports. Also, zero only the xmm register: writing a VEX-encoded xmm destination implicitly zeroes the upper bits of the corresponding ymm/zmm register.
2) Use the integer versions of vmov (vmovdqa / vmovdqa32) instead of vmovaps. On some architectures there is a penalty when data crosses between the floating-point and integer vector execution domains.

Reference: https://stackoverflow.com/questions/33666617/what-is-the-best-way-to-set-a-register-to-zero-in-x86-assembly-xor-mov-or-and

No significant change in performance.

Before:
```
     M,      N,      K,             Type,  GOPS
    64,    800,    320,         MKL_fp32,  46.3
    64,    800,    320,  FBGEMM_i8_acc32,  83.6
    64,    800,    320,  FBGEMM_i8_acc16,  79.2

    64,    768,    512,         MKL_fp32,  46.9
    64,    768,    512,  FBGEMM_i8_acc32,  88.1
    64,    768,    512,  FBGEMM_i8_acc16,  89.2

    16,    256,    512,         MKL_fp32,  27.6
    16,    256,    512,  FBGEMM_i8_acc32,  43.5
    16,    256,    512,  FBGEMM_i8_acc16,  54.1

   128,    128,    128,         MKL_fp32,  30.1
   128,    128,    128,  FBGEMM_i8_acc32,  42.1
   128,    128,    128,  FBGEMM_i8_acc16,  40.6

   256,    512,    256,         MKL_fp32,  44.8
   256,    512,    256,  FBGEMM_i8_acc32,  91.7
   256,    512,    256,  FBGEMM_i8_acc16,  91.1

  1024,   1024,   1024,         MKL_fp32,  48.8
  1024,   1024,   1024,  FBGEMM_i8_acc32,  97.0
  1024,   1024,   1024,  FBGEMM_i8_acc16,  97.6
```
After:
```
     M,      N,      K,             Type,  GOPS
    64,    800,    320,         MKL_fp32,  46.2
    64,    800,    320,  FBGEMM_i8_acc32,  83.5
    64,    800,    320,  FBGEMM_i8_acc16,  80.8

    64,    768,    512,         MKL_fp32,  47.2
    64,    768,    512,  FBGEMM_i8_acc32,  88.5
    64,    768,    512,  FBGEMM_i8_acc16,  87.3

    16,    256,    512,         MKL_fp32,  26.0
    16,    256,    512,  FBGEMM_i8_acc32,  44.0
    16,    256,    512,  FBGEMM_i8_acc16,  54.5

   128,    128,    128,         MKL_fp32,  29.6
   128,    128,    128,  FBGEMM_i8_acc32,  42.2
   128,    128,    128,  FBGEMM_i8_acc16,  38.5

   256,    512,    256,         MKL_fp32,  44.3
   256,    512,    256,  FBGEMM_i8_acc32,  91.1
   256,    512,    256,  FBGEMM_i8_acc16,  91.0

  1024,   1024,   1024,         MKL_fp32,  48.7
  1024,   1024,   1024,  FBGEMM_i8_acc32,  96.6
  1024,   1024,   1024,  FBGEMM_i8_acc16,  96.5
```

Reviewed By: jspark1105

Differential Revision: D21433384

fbshipit-source-id: d0abd56f454293e159d3fda9d94bc84e011060c8
  • Loading branch information
dskhudia authored and facebook-github-bot committed May 14, 2020
1 parent 7ed5f9f commit 46981b8
Show file tree
Hide file tree
Showing 7 changed files with 29 additions and 17 deletions.
6 changes: 3 additions & 3 deletions src/GenerateI8Depthwise.cc
Original file line number Diff line number Diff line change
Expand Up @@ -351,7 +351,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
}
x86::Ymm zero(vreg_id);
if (need_zero && (!recompute_zero || !has_pad)) {
e->vxorps(zero, zero, zero);
e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
}

// Assign scalar registers
Expand Down Expand Up @@ -433,7 +433,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
if (i % 4 == 3 || i == K - 1) {
if (i == K - 1 && (i / 4 * 4 == K - 3 || i / 4 * 4 == K - 1)) {
if (recompute_zero && has_pad) {
e->vxorps(zero, zero, zero);
e->vpxor(zero.xmm(), zero.xmm(), zero.xmm());
}
}

Expand Down Expand Up @@ -465,7 +465,7 @@ GenI8Depthwise::jit_kernel_signature GenI8Depthwise::getOrCreate(
asmjit::Imm(r < 2 ? 0x20 : 0x31));
}
for (int r = 0; r < (main_loop ? 4 : remainder / 8); ++r) {
e->vmovaps(c[r], a[r]);
e->vmovdqa(c[r], a[r]);
}
}
}
Expand Down
6 changes: 4 additions & 2 deletions src/GenerateKernelU8S8S32ACC16.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) {
using CRegs = x86::Ymm;
using CRegs = x86::Xmm;
// Take advantage of implicit zeroing out
// i.e., zero out xmm and ymm will be zeroed out too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
Expand Down
6 changes: 4 additions & 2 deletions src/GenerateKernelU8S8S32ACC16Avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int16_t>::initCRegs<
inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) {
using CRegs = x86::Zmm;
using CRegs = x86::Xmm;
// Take advantage of implicit zeroing out
// i.e., zero out xmm and zmm will be zeroed out too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
Expand Down
8 changes: 5 additions & 3 deletions src/GenerateKernelU8S8S32ACC32.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx2>(x86::Emitter* a, int rowRegs, int colRegs) {
using CRegs = x86::Ymm;
using CRegs = x86::Xmm;
// Take advantage of implicit zeroing out
// i.e., zero out xmm and ymm will be zeroed out too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
Expand Down Expand Up @@ -61,7 +63,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<

for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vmovdqa(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
// load A, broadcast and fmas
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
Expand Down
8 changes: 5 additions & 3 deletions src/GenerateKernelU8S8S32ACC32Avx512.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@ template <>
template <>
void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::initCRegs<
inst_set_t::avx512>(x86::Emitter* a, int rowRegs, int colRegs) {
using CRegs = x86::Zmm;
using CRegs = x86::Xmm;
// Take advantage of implicit zeroing out
// i.e., zero out xmm and zmm will be zeroed out too
for (int i = 0; i < rowRegs; ++i) {
for (int j = 0; j < colRegs; ++j) {
a->vxorps(
a->vpxor(
CRegs(i * colRegs + j),
CRegs(i * colRegs + j),
CRegs(i * colRegs + j));
Expand Down Expand Up @@ -60,7 +62,7 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
using CRegs = x86::Zmm;
for (int j = 0; j < colRegs; ++j) {
// load B
a->vmovaps(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vmovdqa32(BReg, x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
// load A, broadcast and fmas
for (int i = 0; i < rowRegs; ++i) {
a->vpbroadcastd(
Expand Down
3 changes: 2 additions & 1 deletion src/GenerateKernelU8S8S32ACC32Avx512VNNI.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,8 @@ void CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::genComputeBlock<
assert(colRegs * (rowRegs + 1) <= 31);

for (int j = 0; j < colRegs; ++j) {
a->vmovaps(x86::Zmm(30-j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
a->vmovdqa32(
x86::Zmm(30 - j), x86::dword_ptr(buffer_B, j * VLEN_ * sizeof(int8_t)));
}

for (int i = 0; i < rowRegs; i++) {
Expand Down
9 changes: 6 additions & 3 deletions src/GroupwiseConvAcc32Avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -308,11 +308,13 @@ template <int SPATIAL_DIM>
void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::initResultRegs(
x86::Emitter* a) {
if (kLoopIters_ > 0) {
// Take advantage of implicit zeroing out
// i.e., zero out xmm and ymm will be zeroed out too
for (int k = 0; k < kLoopIters_; ++k) {
a->vxorps(x86::Ymm(9 - k), x86::Ymm(9 - k), x86::Ymm(9 - k));
a->vpxor(x86::Xmm(9 - k), x86::Xmm(9 - k), x86::Xmm(9 - k));
}
} else {
a->vxorps(x86::Ymm(9), x86::Ymm(9), x86::Ymm(9));
a->vpxor(x86::Xmm(9), x86::Xmm(9), x86::Xmm(9));
}
}

Expand Down Expand Up @@ -557,7 +559,8 @@ void GenConvKernel<SPATIAL_DIM, inst_set_t::avx2>::genForSingleOutput(

// row offset
if (this->needRowOffset_) {
a->vxorps(rowOffsetReg_V_, rowOffsetReg_V_, rowOffsetReg_V_);
a->vpxor(
rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm(), rowOffsetReg_V_.xmm());
}

bool isWidthMiddle = !isLeft && !isRight;
Expand Down

0 comments on commit 46981b8

Please sign in to comment.