Skip to content

Commit

Permalink
Direct Convolution JIT assembly for KH=2, KW = 6
Browse files Browse the repository at this point in the history
Summary:
this diff has specialized codegen for convolution case where KH=2 and KW=6

## Performance results on local devserver with AVX2 instruction:
1, 16, 16,     {2, 126}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0},     {1, 1}, {0, 0}, false
Fbgemm baseline:
3.8 GOPS
This diff:
9.2 GOPS

1, 64, 64,     {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0},     {1, 1}, {0, 0}, false
Fbgemm baseline:
43.8 GOPS
This diff:
61.2 GOPS

## How to invoke indirect convolution function:
**At offline:**
1. Weights need to be transposed to (oc/8) - (kh) - (kw) - (ic/4) - 8 - 4
2. Create the convolution function based on problem size:
```
       CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
       CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp fn;
       fn = codeObj.getOrCreateDirectConv<inst_set_t::avx2>(
        true, conv_p.OUT_DIM[1], conv_p.IN_DIM[1] * conv_p.IC, conv_p.stride[1] * conv_p.IC);
```
3. Compute the *col_offsets* of weight tensor
4. Make sure you have allocated the space for: output tensor (Cint32_fb, Cint8_fb), and some temporary space for input rowsum ( InSum: IN_DIM[0] x IN_DIM[1], rowSum: OUT_DIM[0] x OUT_DIM[1])

**Online:**
Make sure we have:
conv_p ( the problem info), Aint8 (input tensor), bBuf_tr ( the transposed weight tensor), Cint32_fb ( the 32-bit results after accumulation), Cint8_fb ( the final quantized 8-bit output).

       // compute direct conv row sum
       directConvRowSum(conv_p, Aint8.data(),
            inSum, rowSum, row_offsets.data());

      // kernel for direct convolution
        for (int oc = 0; oc < conv_p.OC; oc+= 8) {
          fn(Aint8.data(),
              bBuf_tr.data() + oc * kernel_dim * conv_p.IC ,
              bBuf_tr.data(),
              Cint32_fb.data() + oc,
              conv_p.IC * conv_p.K[1],
              conv_p.OC);
        }

        requantizationParams_t<> reqObj = {
          Aint8_zero_point, // Aq_zero_point
          Bint8_zero_point.data(),
          C_zero_point,
          C_multiplier.data(),
          rowSum, // row_offsets
          //row_offsets.data(),
          col_offsets.data(), // col_offsets
          nullptr, // bias
          static_cast<std::uint32_t>(conv_p.OC), // ncols
          1, // groups
          nullptr};

        requantizeOutputProcessingAvx2<false, false, QuantizationGranularity::TENSOR,
          false, false>(Cint8_fb.data(),
              Cint32_ref.data(),
              {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC}, conv_p.OC, conv_p.OC, reqObj);

For more details please refer to test_asmjit2.cc

Reviewed By: dskhudia

Differential Revision: D31775222

fbshipit-source-id: 294450613b0978277e75d171d6a560124c14ecda
  • Loading branch information
jiyuanzFB authored and facebook-github-bot committed Dec 18, 2021
1 parent b8c0923 commit 2ffb487
Show file tree
Hide file tree
Showing 4 changed files with 636 additions and 1 deletion.
1 change: 1 addition & 0 deletions defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def get_fbgemm_generic_srcs(with_base = False):
"src/FbgemmI64.cc",
"src/FbgemmSparseDense.cc",
"src/FbgemmI8Spmdm.cc",
"src/GenerateKernelDirectConvU8S8S32ACC32.cc",
"src/GenerateKernel.cc",
"src/GenerateKernelU8S8S32ACC16.cc",
"src/GenerateKernelU8S8S32ACC16Avx512.cc", # Acc16 AVX512 JIT code gen
Expand Down
159 changes: 159 additions & 0 deletions src/DirectConv.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <asmjit/asmjit.h>
#include <cpuinfo.h>
#include <cassert>
#include <cstdint>
#include <map>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <type_traits>
#include "./CodeCache.h"
#include "fbgemm/ConvUtils.h"
#include "fbgemm/Fbgemm.h"
#include "fbgemm/Utils.h"
/*#define FBGEMM_LOG_CODE 1*/

namespace fbgemm {

namespace x86 = asmjit::x86;

/**
* @brief Generate instructions for initializing the C registers to 0.
*/
void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);

static asmjit::JitRuntime& runtime() {
static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
// depents on other static
// variables. Required to prevent
// initialization order fiasco
return rt;
}

template <typename TA, typename TB, typename TC, typename accT>
class DirectConvCodeGenBase {
public:
using jit_micro_kernel_fp = void (*)(
const TA* bufferA,
const TB* bufferB,
const TB* b_pf,
TC* bufferC,
int kc,
int ldc);

static std::mutex rtMutex_; ///< Control access to runtime;

// The hash depends on accumulate, mc, nc, ncb, kcb, nr, mr
static CodeCache<
std::tuple<bool, int, int, int, int, int, int>,
jit_micro_kernel_fp>
codeCache_; ///< JIT Code Cache for reuse.

/**
* @brief Generate instructions for storing the C registers back to the
* memory.
*/
template <inst_set_t instSet>
void storeCRegs(
x86::Emitter* a,
int rowRegs,
int colRegs,
x86::Gp C_Offset,
x86::Gp ldcReg,
bool accum);

/**
* @brief Generate filename to dump generated code
* (debug-only)
*/
template <inst_set_t instSet>
static std::string getCodeLoggingFile(
bool accum,
int mc,
int nc,
int NCB,
int KCB,
int MR,
int NR) {
std::ostringstream oss;
oss << "directconv_";
if (std::is_same<accT, std::int16_t>::value) {
oss << "acc16_";
} else if (std::is_same<accT, std::int32_t>::value) {
oss << "acc32_";
} else {
oss << "unknown_";
}
oss << "accum-" + std::to_string(accum) << "_MC-" + std::to_string(mc)
<< "_NC-" + std::to_string(nc) << "_NCB-" + std::to_string(NCB)
<< "_KCB-" + std::to_string(KCB) << "_MR-" + std::to_string(MR)
<< "_NR-" + std::to_string(NR);
if (instSet == inst_set_t::avx512_vnni) {
oss << "_avx512vnni";
} else if (instSet == inst_set_t::avx512) {
oss << "_avx512";
} else if (instSet == inst_set_t::avx512_ymm) {
oss << "_avx512_ymm";
} else if (instSet == inst_set_t::avx2) {
oss << "_avx2";
}
oss << ".txt";
return oss.str();
}

/**
* @brief Get or Create the instructions for macro-kernel.
*
* If the problem size (mc, nc) and accumulation flag (accum) can be found in
* the code cache (a hash map), then get the macro-kernel instructions
* directly from it. Otherwise, create the instructions for macro-kernel, and
* store that into the code cache.
*/
template <inst_set_t instSet>
jit_micro_kernel_fp
getOrCreateDirectConv(bool accum, int32_t mc, int32_t nc, int32_t kc);

/**
* @brief Generate instructions for computing block in the rank-k update.
*/
template <inst_set_t instSet>
void genComputeBlock(
x86::Emitter* a,
x86::Gp buffer_A,
x86::Gp buffer_B,
x86::Gp B_pf,
int rowRegs,
int colRegs,
int lda);
/**
* @brief Generate instructions for computing block in the rank-k update.
*/
template <inst_set_t instSet>
void genComputeBlockDirectConv(
x86::Emitter* a,
x86::Gp buffer_A,
x86::Gp buffer_B,
x86::Gp B_pf,
int rowRegs,
int colRegs,
int strideXich);
};

template <typename TA, typename TB, typename TC, typename accT>
std::mutex DirectConvCodeGenBase<TA, TB, TC, accT>::rtMutex_;

template <typename TA, typename TB, typename TC, typename accT>
CodeCache<
std::tuple<bool, int, int, int, int, int, int>,
typename DirectConvCodeGenBase<TA, TB, TC, accT>::jit_micro_kernel_fp>
DirectConvCodeGenBase<TA, TB, TC, accT>::codeCache_;

}; // namespace fbgemm
2 changes: 1 addition & 1 deletion src/GenerateKernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include <tuple>
#include "./CodeCache.h"
#include "fbgemm/Fbgemm.h"
/*#define FBGEMM_LOG_CODE 1*/
//#define FBGEMM_LOG_CODE 1

namespace fbgemm {

Expand Down
Loading

0 comments on commit 2ffb487

Please sign in to comment.