Direct Convolution JIT assembly for KH=2, KW=6
Summary: This diff adds specialized codegen for the convolution case where KH=2 and KW=6.

## Performance results on a local devserver with AVX2 instructions

`1, 16, 16, {2, 126}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false`
- FBGEMM baseline: 3.8 GOPS
- This diff: 9.2 GOPS

`1, 64, 64, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false`
- FBGEMM baseline: 43.8 GOPS
- This diff: 61.2 GOPS

## How to invoke the direct convolution function

**Offline:**

1. Transpose the weights to the (oc/8) - (kh) - (kw) - (ic/4) - 8 - 4 layout (a sketch is included at the end of this summary).
2. Create the convolution function based on the problem size:

    ```
    CodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
    CodeGenBase<uint8_t, int8_t, int32_t, int32_t>::jit_micro_kernel_fp fn;

    fn = codeObj.getOrCreateDirectConv<inst_set_t::avx2>(
        true,
        conv_p.OUT_DIM[1],
        conv_p.IN_DIM[1] * conv_p.IC,
        conv_p.stride[1] * conv_p.IC);
    ```

3. Compute the *col_offsets* of the weight tensor (a sketch is included at the end of this summary).
4. Make sure you have allocated the output tensors (Cint32_fb, Cint8_fb) and temporary space for the input row sums (inSum: IN_DIM[0] x IN_DIM[1]; rowSum: OUT_DIM[0] x OUT_DIM[1]).

**Online:**

Make sure we have: conv_p (the problem info), Aint8 (the input tensor), bBuf_tr (the transposed weight tensor), Cint32_fb (the 32-bit results after accumulation), and Cint8_fb (the final quantized 8-bit output).

```
// compute the direct conv row sums
directConvRowSum(conv_p, Aint8.data(), inSum, rowSum, row_offsets.data());

// kernel for direct convolution
for (int oc = 0; oc < conv_p.OC; oc += 8) {
  fn(Aint8.data(),
     bBuf_tr.data() + oc * kernel_dim * conv_p.IC,
     bBuf_tr.data(),
     Cint32_fb.data() + oc,
     conv_p.IC * conv_p.K[1],
     conv_p.OC);
}

requantizationParams_t<> reqObj = {
    Aint8_zero_point, // Aq_zero_point
    Bint8_zero_point.data(),
    C_zero_point,
    C_multiplier.data(),
    rowSum, // row_offsets
    // row_offsets.data(),
    col_offsets.data(), // col_offsets
    nullptr, // bias
    static_cast<std::uint32_t>(conv_p.OC), // ncols
    1, // groups
    nullptr};

requantizeOutputProcessingAvx2<
    false,
    false,
    QuantizationGranularity::TENSOR,
    false,
    false>(
    Cint8_fb.data(),
    Cint32_ref.data(),
    {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
    conv_p.OC,
    conv_p.OC,
    reqObj);
```

For more details, please refer to test_asmjit2.cc.
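As referenced in offline step 1, here is a minimal sketch of the weight transpose into the (oc/8) - (kh) - (kw) - (ic/4) - 8 - 4 layout. The source layout is assumed here to be (kh, kw, ic, oc); that assumption, and the helper name `transposeWeightsForDirectConv`, are illustrative and not taken from this diff.

```
#include <cstdint>

// Sketch: pack weights from an assumed (kh, kw, ic, oc) source layout into
// the (oc/8)-(kh)-(kw)-(ic/4)-8-4 blocked layout the JIT kernel expects.
// Assumes IC is a multiple of 4 and OC is a multiple of 8.
void transposeWeightsForDirectConv(
    const std::int8_t* src, // assumed [KH][KW][IC][OC]
    std::int8_t* dst,       // [OC/8][KH][KW][IC/4][8][4]
    int KH, int KW, int IC, int OC) {
  for (int ocb = 0; ocb < OC / 8; ++ocb) {
    for (int kh = 0; kh < KH; ++kh) {
      for (int kw = 0; kw < KW; ++kw) {
        for (int icb = 0; icb < IC / 4; ++icb) {
          for (int o = 0; o < 8; ++o) {   // 8 output channels per block
            for (int i = 0; i < 4; ++i) { // 4 input channels per block
              int oc = ocb * 8 + o;
              int ic = icb * 4 + i;
              dst[((((ocb * KH + kh) * KW + kw) * (IC / 4) + icb) * 8 + o) * 4 + i] =
                  src[((kh * KW + kw) * IC + ic) * OC + oc];
            }
          }
        }
      }
    }
  }
}
```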
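For offline step 3, the col_offsets are the per-output-channel sums of the int8 weights, consumed later by the requantization step. A minimal sketch, assuming the unpacked weight tensor in the same (kh, kw, ic, oc) order as above; `computeColOffsets` is a hypothetical helper, not an FBGEMM API:

```
#include <cstdint>
#include <vector>

// Sketch: col_offsets[oc] = sum of all int8 weights that feed output
// channel oc, accumulated in 32 bits.
std::vector<std::int32_t> computeColOffsets(
    const std::int8_t* weights, // assumed [KH][KW][IC][OC]
    int KH, int KW, int IC, int OC) {
  std::vector<std::int32_t> col_offsets(OC, 0);
  for (int k = 0; k < KH * KW * IC; ++k) {
    for (int oc = 0; oc < OC; ++oc) {
      col_offsets[oc] += weights[k * OC + oc];
    }
  }
  return col_offsets;
}
```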
Reviewed By: dskhudia

Differential Revision: D31775222

fbshipit-source-id: 294450613b0978277e75d171d6a560124c14ecda