From f0f6ca70dd75c62098464fc26f97636b27bc89ff Mon Sep 17 00:00:00 2001
From: Jiyuan Zhang
Date: Tue, 8 Feb 2022 15:23:18 -0800
Subject: [PATCH] add direct conv to fbgemm (#901)

Summary:
Pull Request resolved: https://github.com/pytorch/FBGEMM/pull/901

Integrate the direct convolution code path into FBGEMM. The direct
convolution entry point is integrated through FbgemmConv.cc:
ConvFastPath automatically determines whether a case can use the
direct convolution branch:
- spatial_dim = 2, kh = 2, kw <= 6, stride = 1 or 2, padding = 0

Reviewed By: jspark1105

Differential Revision: D32273614

fbshipit-source-id: 16255395b4e14fae10c129ad98dd4445a5106989
---
 defs.bzl                                    |   2 +
 include/fbgemm/Fbgemm.h                     |  23 +
 include/fbgemm/FbgemmI8DirectconvAvx2.h     |  60 ++
 include/fbgemm/Utils.h                      |   3 +-
 src/DirectConv.h                            |  22 +-
 src/FbgemmConv.cc                           |  60 ++
 src/GenerateKernelDirectConvU8S8S32ACC32.cc |  51 +-
 src/PackWeightsForConv.cc                   |   8 +
 src/PackWeightsForDirectConv.cc             | 494 ++++++++++++
 test/I8DirectconvTest.cc                    | 843 ++++++++++++++++++++
 test/UniConvTest.cc                         | 108 +++
 11 files changed, 1639 insertions(+), 35 deletions(-)
 create mode 100644 include/fbgemm/FbgemmI8DirectconvAvx2.h
 create mode 100644 src/PackWeightsForDirectConv.cc
 create mode 100644 test/I8DirectconvTest.cc

diff --git a/defs.bzl b/defs.bzl
index 591fcc56a5..5dc87c0775 100644
--- a/defs.bzl
+++ b/defs.bzl
@@ -43,6 +43,7 @@ def get_fbgemm_generic_srcs(with_base = False):
         "src/PackMatrix.cc",
         "src/PackWeightMatrixForGConv.cc",
         "src/PackWeightsForConv.cc",
+        "src/PackWeightsForDirectConv.cc",
         "src/QuantUtils.cc",
         "src/RowWiseSparseAdagradFused.cc",
         "src/SparseAdagrad.cc",
@@ -61,6 +62,7 @@ def get_fbgemm_public_headers():
         "include/fbgemm/FbgemmFPCommon.h",
         "include/fbgemm/FbgemmI64.h",
         "include/fbgemm/FbgemmI8DepthwiseAvx2.h",
+        "include/fbgemm/FbgemmI8DirectconvAvx2.h",
         "include/fbgemm/FbgemmI8Spmdm.h",
         "include/fbgemm/FbgemmPackMatrixB.h",
         "include/fbgemm/FbgemmSparse.h",
diff --git a/include/fbgemm/Fbgemm.h b/include/fbgemm/Fbgemm.h
index ca4f090a08..2034d245b6 100644
--- a/include/fbgemm/Fbgemm.h
+++ b/include/fbgemm/Fbgemm.h
@@ -18,6 +18,7 @@
 #include "./FbgemmBuild.h"
 #include "./FbgemmEmbedding.h"
 #include "./FbgemmI8DepthwiseAvx2.h"
+#include "./FbgemmI8DirectconvAvx2.h"
 #include "./FbgemmI8Spmdm.h"
 #include "./QuantUtilsAvx2.h"
 #include "./Types.h"
@@ -601,6 +602,10 @@ class FBGEMM_API PackWeightsForConv {
     return W_dw_packed_;
   }

+  std::shared_ptr<PackedDirectConvMatrix> getPackedWForDirectconv() {
+    return W_dc_packed_;
+  }
+
   std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
   getPackedWForGroupwise() {
     return W_gconv_packed_;
@@ -649,6 +654,8 @@ class FBGEMM_API PackWeightsForConv {
   std::shared_ptr<PackBMatrix<T, accT>> W_im2col_packed_;
   // Packed weights if we use depthwise convolution implementation
   std::shared_ptr<PackedDepthWiseConvMatrix> W_dw_packed_;
+  // Packed weights if we use direct convolution implementation
+  std::shared_ptr<PackedDirectConvMatrix> W_dc_packed_;
   // Packed weights if we use groupwise (small channels per group) convolution
   // implementation
   std::shared_ptr<PackWeightMatrixForGConv<T, accT, SPATIAL_DIM>>
@@ -1384,6 +1391,22 @@ FBGEMM_API void fbgemmGroupwiseConv(
     int thread_id,
     int num_threads);

+template <
+    int SPATIAL_DIM,
+    QuantizationGranularity Q_GRAN,
+    bool FUSE_RELU,
+    typename BIAS_TYPE = std::int32_t>
+FBGEMM_API void fbgemmDirectConv(
+    const conv_param_t<SPATIAL_DIM>& conv_p,
+    const uint8_t* Aint8,
+    PackedDirectConvMatrix& Bint8_tr,
+    uint8_t* C,
+    int32_t* C_buffer,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
+    const BIAS_TYPE* bias,
+    int thread_id,
+    int num_threads);
+
 /**
  * @return Size of row offset buffer in number of elements needed for
  * fbgemmGroupwiseConv
diff --git a/include/fbgemm/FbgemmI8DirectconvAvx2.h b/include/fbgemm/FbgemmI8DirectconvAvx2.h
new file mode 100644
index 0000000000..7b85fbf238
--- /dev/null
+++ b/include/fbgemm/FbgemmI8DirectconvAvx2.h
@@ -0,0 +1,60 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+
+#include <cstdint>
+#include <string>
+#include <vector>
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/FbgemmBuild.h"
+#include "fbgemm/UtilsAvx2.h"
+
+namespace fbgemm {
+
+class FBGEMM_API PackedDirectConvMatrix {
+ public:
+  /**
+   * @param IC_per_G the number of input channels per group
+   * @param OC_per_G the number of output channels per group
+   * @param filter_prod the product of all kernel dimensions. For example,
+   *                    filter_prod = 9 for 3x3 conv, and 27 for 3x3x3 conv.
+   * @param smat the source unpacked weights in GRS layout
+   */
+  PackedDirectConvMatrix(
+      int IC_per_G,
+      int OC_per_G,
+      int filter_prod,
+      const std::int8_t* smat);
+  virtual ~PackedDirectConvMatrix();
+
+  const std::int8_t* PackedMat() const {
+    return pmat_;
+  }
+
+  const bool& is_first_call() const {
+    return first_call;
+  }
+
+  /**
+   * Compute the column offsets of the weight matrix.
+   * The output of this function is the col_offsets vector;
+   * the col_offsets dimension is the same as conv_p.OUT_DIM.
+   */
+  template <int SPATIAL_DIM>
+  FBGEMM_API void col_offsets_with_zero_pt_s8acc32_DirectConvT(
+      const fbgemm::conv_param_t<SPATIAL_DIM>& conv_p,
+      std::int32_t* B_zero_point,
+      std::vector<int32_t>& col_offsets,
+      int ncols_per_quant_group);
+
+ private:
+  std::int8_t* pmat_; /** packed weight */
+  bool first_call{true};
+};
+
+} // namespace fbgemm
diff --git a/include/fbgemm/Utils.h b/include/fbgemm/Utils.h
index 8f21168bf5..1df4390c56 100644
--- a/include/fbgemm/Utils.h
+++ b/include/fbgemm/Utils.h
@@ -58,7 +58,8 @@ enum class optimized_conv_t {
   groupwise,
   pointwise,
   fastpath1d,
-  im2col
+  im2col,
+  directconv
 };

 /**
diff --git a/src/DirectConv.h b/src/DirectConv.h
index d06476cab2..ef97c36017 100644
--- a/src/DirectConv.h
+++ b/src/DirectConv.h
@@ -30,14 +30,6 @@ namespace x86 = asmjit::x86;
  */
 void initCRegs(x86::Emitter* a, int rowRegs, int colRegs);

-static asmjit::JitRuntime& runtime() {
-  static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
-                                // depents on other static
-                                // variables. Required to prevent
-                                // initialization order fiasco
-  return rt;
-}
-
 template <typename TA, typename TB, typename TC, typename accT>
 class DirectConvCodeGenBase {
  public:
@@ -164,9 +156,8 @@ class DirectConvCodeGenBase {
    * store that into the code cache.
    */
   template <inst_set_t instSet>
-  jit_micro_kernel_fp_convT getOrCreateDirectConvTrans(
-      bool accum,
-      int32_t stride);
+  jit_micro_kernel_fp_convT
+  getOrCreateDirectConvTrans(bool accum, int32_t stride, int32_t numColRegs);

   /**
    * @brief Generate instructions for computing block in the rank-k update.
@@ -205,6 +196,15 @@ class DirectConvCodeGenBase {
       x86::Gp C_offset,
       int rowRegs,
       int colRegs);
+
+ private:
+  static asmjit::JitRuntime& runtime() {
+    static asmjit::JitRuntime rt; //< JIT Runtime for asmjit,
+                                  // depends on other static
+                                  // variables. Required to prevent
+                                  // initialization order fiasco
+    return rt;
+  }
 };

 template <typename TA, typename TB, typename TC, typename accT>
diff --git a/src/FbgemmConv.cc b/src/FbgemmConv.cc
index c40bd63b50..433b469eca 100644
--- a/src/FbgemmConv.cc
+++ b/src/FbgemmConv.cc
@@ -67,6 +67,35 @@ bool take1DFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
   return false && !conv_p.transposed;
 }

+template <int SPATIAL_DIM, typename ACC_T>
+bool takeDirectConvPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
+  // Note: direct convolutions (2D) are optimized for
+  //   filter size: 2 x 1 to 2 x 6, transposed conv,
+  //   in_channel % 8 == 0, out_channel % 8 == 0,
+  //   stride = 1 or 2,
+  //   padding = 0 (non-zero padding will be supported soon)
+  bool ret = std::is_same<ACC_T, std::int32_t>::value && conv_p.transposed &&
+      conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 &&
+      std::all_of(
+          conv_p.stride.begin(),
+          conv_p.stride.end(),
+          [](int i) { return i == 1 || i == 2; }) &&
+      SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 &&
+      conv_p.K[SPATIAL_DIM - 1] <= 6 &&
+      std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) {
+        return i == 1;
+      });
+
+  // Check pads: zero padding only
+  for (int i = 0; i < SPATIAL_DIM; ++i) {
+    if (conv_p.pad[i] != 0) {
+      ret = false;
+    }
+  }
+
+  return ret;
+}
+
 template <int SPATIAL_DIM, typename ACC_T>
 optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
   if (takeDepthWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
     return optimized_conv_t::depthwise;
   } else if (takeGConvFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
@@ -75,6 +104,8 @@ optimized_conv_t ConvFastPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
     return optimized_conv_t::groupwise;
   } else if (takePointWiseFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
     return optimized_conv_t::pointwise;
+  } else if (takeDirectConvPath<SPATIAL_DIM, ACC_T>(conv_p)) {
+    return optimized_conv_t::directconv;
   } else if (take1DFastPath<SPATIAL_DIM, ACC_T>(conv_p)) {
     return optimized_conv_t::fastpath1d;
   } else {
@@ -309,6 +340,26 @@ int fbgemmConv(
           blocking_params);
       break;
     }
+    case optimized_conv_t::directconv: {
+      // specialized direct convolution path
+      // std::cout << "Directconv fast path" << std::endl;
+      static_assert(
+          std::is_same<typename processOutputType::outType, std::uint8_t>::
+              value,
+          "For directconv 2d, only requantized output is supported");
+      fbgemmDirectConv(
+          conv_p,
+          // Aint8,
+          activations,
+          *(packed_weights.getPackedWForDirectconv()),
+          out,
+          outBuffer,
+          outProcess,
+          outProcess.getBias(),
+          thread_id,
+          num_threads);
+      break;
+    }
     case optimized_conv_t::fastpath1d: {
       break;
     }
@@ -416,6 +467,15 @@ template bool takeDepthWiseFastPath<2, std::int16_t>(
     const conv_param_t<2>& conv_p);
 template bool takeDepthWiseFastPath<3, std::int16_t>(
     const conv_param_t<3>& conv_p);

+template bool takeDirectConvPath<2, std::int32_t>(
+    const conv_param_t<2>& conv_p);
+template bool takeDirectConvPath<3, std::int32_t>(
+    const conv_param_t<3>& conv_p);
+template bool takeDirectConvPath<2, std::int16_t>(
+    const conv_param_t<2>& conv_p);
+template bool takeDirectConvPath<3, std::int16_t>(
+    const conv_param_t<3>& conv_p);
+
 template FBGEMM_API optimized_conv_t
 ConvFastPath<1, std::int32_t>(const conv_param_t<1>& conv_p);
 template FBGEMM_API optimized_conv_t
diff --git a/src/GenerateKernelDirectConvU8S8S32ACC32.cc b/src/GenerateKernelDirectConvU8S8S32ACC32.cc
index 1cb9f342f1..576286b034 100644
--- a/src/GenerateKernelDirectConvU8S8S32ACC32.cc
+++ b/src/GenerateKernelDirectConvU8S8S32ACC32.cc
@@ -12,7 +12,7 @@ namespace fbgemm {
 namespace x86 = asmjit::x86;

 /**
- * Generate AVX512 instructions for computing block in the rank-k update of
+ * Generate AVX256 instructions for computing block in the rank-k update of
  * 32-bit Accumulation kernel.
  *
  * this compute block implements the following register blocking
  *
@@ -33,7 +33,7 @@ namespace x86 = asmjit::x86;
  */

 /**
- * Generate AVX512 instructions for storing the C registers back to the memory
+ * Generate AVX256 instructions for storing the C registers back to the memory
  * in 32-bit Accumulation kernel.
  */
 template <>
@@ -421,7 +421,7 @@ DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::getOrCreateDirectConv(
 }

 /**
- * Generate AVX512 instructions for storing the C registers back to the memory
+ * Generate AVX256 instructions for storing the C registers back to the memory
  * in 32-bit Accumulation kernel.
  */
 template <>
@@ -454,7 +454,7 @@ void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::storeCRegsTrans(
 }

 /**
- * Generate AVX512 instructions for computing block in the rank-k update of
+ * Generate AVX256 instructions for computing block in the rank-k update of
  * 32-bit Accumulation kernel.

  The function generates the register blocking code for transposed
@@ -588,22 +588,32 @@ void DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
  *
  */
+
+/**
+ * Get or Create the AVX256 instructions for 32-bit Accumulation macro-kernel.
+ *
+ */
 template <>
 template <inst_set_t instSet>
 DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
     jit_micro_kernel_fp_convT
     DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
-    getOrCreateDirectConvTrans(bool accum, int32_t stride) {
+    getOrCreateDirectConvTrans(
+        bool accum,
+        int32_t stride,
+        int32_t numColRegs) {
   using VecRegT = typename simd_info<instSet>::vec_reg_t;
   constexpr int numRegs = simd_info<instSet>::NUM_VEC_REGS;
-  constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;
+  static constexpr int vectorLen = simd_info<instSet>::WIDTH_BYTES;

   std::tuple<bool, int32_t, int32_t, int32_t> kernelSig;
-  constexpr int mRowRegBlockSize = 2;
-  constexpr int mColRegBlockSize = 6;
-  constexpr int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
-  constexpr int nRegBlockSize = 8;
-  constexpr int row_interleave = 4;
+  // int ichSize = 32;
+  int mRowRegBlockSize = 2;
+  int mColRegBlockSize = numColRegs;
+  int mRegBlockSize = mRowRegBlockSize * mColRegBlockSize;
+  int nRegBlockSize = 8;
+  // int nRegBlockSizeMin;
+  int row_interleave = 4;

   kernelSig = std::make_tuple(accum, stride, mRegBlockSize, nRegBlockSize);

@@ -615,8 +625,7 @@
 #if defined(FBGEMM_LOG_CODE)
     // generated code logging
     FILE* codeLogfile = fopen(
-        getCodeLoggingFile(
-            accum, stride, 0, 0, 0, mRegBlockSize, nRegBlockSize)
+        getCodeLoggingFile(accum, stride, mRegBlockSize, nRegBlockSize)
             .c_str(),
         "w");
     asmjit::FileLogger* codeLogger = new asmjit::FileLogger(codeLogfile);
@@ -689,9 +698,6 @@
     gen16BitVectorOne<instSet, VecRegT>(a, oneReg);

     a->imul(ldcReg, ldcReg, static_cast<asmjit::Imm>(sizeof(int32_t)));
-    // a->xor_(C_Offset.r32(), C_Offset.r32());
-
-    // a->mov(B_pf_saved, B_pf);

     int colRegs = maxNRegs;
@@ -702,7 +708,6 @@
       initCRegs(a, rowRegs, colRegs);

       // Loops over K: input channel
-      // corresponds to the "icb" loop in the pseudo code
       a->xor_(kIdx.r32(), kIdx.r32());
       a->bind(LoopKLabel);
@@ -753,13 +758,12 @@
       // B for next block
       a->mov(buffer_B, buffer_B_saved);
       // increment C for next B block
-      // ldcReg already multiplied by 4 (sizeof(int32_t))
-      a->imul(C_offset, ldcReg, static_cast<asmjit::Imm>(stride));
+      a->imul(
+          C_offset,
+          ldcReg,
+          static_cast<asmjit::Imm>(stride)); // ldcReg already multiplied by 4
       a->add(CBase, C_offset);

-      // a->add(CBase, static_cast<asmjit::Imm>(12*16*4));
-      // storeCRegs(a, 12, 1, C_Offset, ldcReg, accum);
-
       a->cmp(iIdx, i1);
       a->jl(LoopMBlocks);
     }
@@ -889,6 +893,7 @@
 template DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
     jit_micro_kernel_fp_convT
     DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
         getOrCreateDirectConvTrans<inst_set_t::avx2>(
             bool accum,
-            int32_t stride);
+            int32_t stride,
+            int32_t numColRegs);

 } // namespace fbgemm
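Note on the packed B layout: the transposed-conv kernel above consumes weights in the W[oc/8][r][s][ic/4][8][4] layout produced by PackedDirectConvMatrix in the next diff, i.e. blocks of 8 output channels and 4 interleaved input channels (matching row_interleave = 4 above). A minimal standalone sketch of that index mapping; pack_direct_conv is a hypothetical helper, not part of the patch, assuming G == 1, OC % 8 == 0, and IC % 4 == 0:

#include <cstdint>
#include <vector>

// Repack smat[oc][f][ic] (GRS layout, G == 1) into W[oc/8][f][ic/4][8][4],
// mirroring the PackedDirectConvMatrix constructor in the next diff.
std::vector<std::int8_t> pack_direct_conv(
    const std::vector<std::int8_t>& smat, int OC, int filter_prod, int IC) {
  // same rounded-up allocation as the constructor: OC to 32, filter_prod to 2
  std::vector<std::int8_t> pmat(
      ((OC + 31) / 32 * 32) * ((filter_prod + 1) / 2 * 2) * IC);
  for (int k = 0; k < OC; ++k) {
    for (int f = 0; f < filter_prod; ++f) {
      for (int c = 0; c < IC; ++c) {
        pmat[((((k / 8) * filter_prod + f) * (IC / 4) + c / 4) * 8 + k % 8) *
                 4 +
             c % 4] = smat[(k * filter_prod + f) * IC + c];
      }
    }
  }
  return pmat;
}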
diff --git a/src/PackWeightsForConv.cc b/src/PackWeightsForConv.cc
index 86267833d2..f3059ab718 100644
--- a/src/PackWeightsForConv.cc
+++ b/src/PackWeightsForConv.cc
@@ -52,6 +52,14 @@ PackWeightsForConv<SPATIAL_DIM, T, accT>::PackWeightsForConv(
           blocking_params);
       break;
     }
+    case optimized_conv_t::directconv: {
+      const int kernel_h = SPATIAL_DIM == 1 ? 1 : conv_p.K[SPATIAL_DIM - 2];
+      const int kernel_w = conv_p.K[SPATIAL_DIM - 1];
+      const int K = kernel_h * kernel_w;
+      W_dc_packed_ = std::make_shared<PackedDirectConvMatrix>(
+          conv_p.IC, conv_p.OC, K, sdata);
+      break;
+    }
     case optimized_conv_t::fastpath1d: {
       break;
     }
diff --git a/src/PackWeightsForDirectConv.cc b/src/PackWeightsForDirectConv.cc
new file mode 100644
index 0000000000..20de91517e
--- /dev/null
+++ b/src/PackWeightsForDirectConv.cc
@@ -0,0 +1,494 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+#define FBGEMM_EXPORTS
+#include "fbgemm/FbgemmI8DirectconvAvx2.h"
+
+#include <algorithm>
+#include <cstring>
+
+#include "./DirectConv.h"
+#include "./ExecuteKernel.h"
+#include "./MaskAvx2.h"
+#include "fbgemm/ConvUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/FbgemmBuild.h"
+#include "fbgemm/UtilsAvx2.h"
+
+#include "./CodeGenHelpers.h"
+#include "./OptimizedKernelsAvx2.h"
+#include "./RefImplementations.h"
+#include "./TransposeUtils.h"
+#include "fbgemm/QuantUtilsAvx512.h"
+namespace fbgemm {
+
+PackedDirectConvMatrix::PackedDirectConvMatrix(
+    int IC_per_G,
+    int OC_per_G,
+    int filter_prod,
+    const int8_t* smat) {
+  // Allocate packed arrays
+  int kernel_prod_aligned = (filter_prod + 1) / 2 * 2;
+  pmat_ = static_cast<int8_t*>(fbgemmAlignedAlloc(
+      64,
+      ((OC_per_G + 31) / 32 * 32) * kernel_prod_aligned * IC_per_G *
+          sizeof(int8_t)));
+
+  // the transposed weight layout: W[oc/8][r][s][ic/4][8][4]
+  for (int g = 0; g < /* G */ 1; ++g) {
+    for (int k = 0; k < OC_per_G; ++k) {
+      for (int f = 0; f < filter_prod; ++f) {
+        for (int c = 0; c < IC_per_G; ++c) {
+          int ocB = k / 8;
+          int ocb = k % 8;
+          int icB = c / 4;
+          int icb = c % 4;
+          pmat_
+              [((((g * (OC_per_G / 8) + ocB) * filter_prod + f) *
+                     (IC_per_G / 4) +
+                 icB) *
+                    8 +
+                ocb) *
+                   4 +
+               icb] =
+                  smat[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
+        }
+      }
+    }
+  }
+}
+
+PackedDirectConvMatrix::~PackedDirectConvMatrix() {
+  fbgemmAlignedFree(pmat_);
+}
+
+template <int SPATIAL_DIM>
+void PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT(
+    const fbgemm::conv_param_t<SPATIAL_DIM>& conv_p,
+    std::int32_t* B_zero_point,
+    std::vector<int32_t>& col_offsets,
+    int ncols_per_quant_group) {
+  // If the direct convolution implementation is used, compute the col_offsets
+  // of the weight matrix on the first inference pass.
+  // We need to know the shape of the output matrix to compute col_offsets for
+  // direct convolution, so this cannot be done inside the weight packing
+  // function at initialization time as in the other quantized conv
+  // implementations. The col_offsets computation is instead invoked during
+  // the forward pass, and only the first pass prepares the col_offsets.
+  if (first_call == false) {
+    return;
+  }
+  int IC = conv_p.IC;
+  int OC = conv_p.OC;
+
+  int IN_DIM0 = conv_p.IN_DIM[0];
+  int IN_DIM1 = conv_p.IN_DIM[1];
+  int OUT_DIM0 = conv_p.OUT_DIM[0];
+  int OUT_DIM1 = conv_p.OUT_DIM[1];
+  int K0 = conv_p.K[0];
+  int K1 = conv_p.K[1];
+  int stride0 = conv_p.stride[0];
+  int stride1 = conv_p.stride[1];
+
+  int MDim = conv_p.MB * OUT_DIM0 * OUT_DIM1;
+  int NDim = conv_p.OC / conv_p.G;
+  // int KDim = K[0] * K[1] * conv_p.IC;
+
+  col_offsets.resize(MDim * NDim, 0);
+  std::fill(col_offsets.begin(), col_offsets.end(), 0);
+  std::vector<int32_t> count(MDim * NDim, 0);
+
+  for (int oc = 0; oc < OC; oc++) {
+    for (int ih = 0; ih < IN_DIM0; ih++) {
+      for (int iw = 0; iw < IN_DIM1; iw++) {
+        for (int kh = 0; kh < K0; kh++) {
+          for (int kw = 0; kw < K1; kw++) {
+            for (int ic = 0; ic < IC; ic++) {
+              int oh = ih * stride0 + kh;
+              int ow = iw * stride1 + kw;
+              col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] += pmat_
+                  [(((((oc / 8) * K0 + kh) * K1 + kw) * (IC / 4) + ic / 4) * 8 +
+                    (oc % 8)) *
+                       4 +
+                   (ic % 4)];
+              count[(oh * OUT_DIM1 + ow) * OC + oc]++;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int oc = 0; oc < OC; oc++) {
+    for (int oh = 0; oh < OUT_DIM0; oh++) {
+      for (int ow = 0; ow < OUT_DIM1; ow++) {
+        col_offsets[(oh * OUT_DIM1 + ow) * OC + oc] -=
+            B_zero_point[oc / ncols_per_quant_group] *
+            count[(oh * OUT_DIM1 + ow) * OC + oc];
+      }
+    }
+  }
+
+  first_call = false;
+}
+
+template FBGEMM_API void
+PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<1>(
+    const fbgemm::conv_param_t<1>& conv_p,
+    std::int32_t* B_zero_point,
+    std::vector<int32_t>& col_offsets,
+    int ncols_per_quant_group);
+
+template FBGEMM_API void
+PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<2>(
+    const fbgemm::conv_param_t<2>& conv_p,
+    std::int32_t* B_zero_point,
+    std::vector<int32_t>& col_offsets,
+    int ncols_per_quant_group);
+
+template FBGEMM_API void
+PackedDirectConvMatrix::col_offsets_with_zero_pt_s8acc32_DirectConvT<3>(
+    const fbgemm::conv_param_t<3>& conv_p,
+    std::int32_t* B_zero_point,
+    std::vector<int32_t>& col_offsets,
+    int ncols_per_quant_group);
+
+template <int SPATIAL_DIM>
+void directConvRowSum(
+    const conv_param_t<SPATIAL_DIM>& conv_p,
+    const uint8_t* A,
+    int32_t* inSum,
+    int32_t* rowSum) {
+  int IN0 = conv_p.IN_DIM[0];
+  int IN1 = conv_p.IN_DIM[1];
+  int IC = conv_p.IC;
+  int K0 = conv_p.K[0];
+  int K1 = conv_p.K[1];
+  int OUT0 = conv_p.OUT_DIM[0];
+  int OUT1 = conv_p.OUT_DIM[1];
+  int stride = conv_p.stride[1];
+
+  memset(rowSum, 0, sizeof(int32_t) * OUT0 * OUT1);
+
+  for (int ih = 0; ih < IN0; ++ih) {
+    for (int iw = 0; iw < IN1; ++iw) {
+      inSum[ih * IN1 + iw] = reduceAvx2(A + ih * IN1 * IC + iw * IC, IC);
+    }
+  }
+
+  for (int ih = 0; ih < IN0; ++ih) {
+    for (int iw = 0; iw < IN1; iw++) {
+      for (int r = 0; r < K0; ++r) {
+        for (int s = 0; s < K1; ++s) {
+          rowSum[(ih + r) * OUT1 + iw * stride + s] += inSum[ih * IN1 + iw];
+        }
+      }
+    }
+  }
+  /*
+  compare_buffers(
+      rowSum,
+      rowoffsets,
+      OUT0,
+      OUT1,
+      OUT1,
+      5);
+  */
+}
+
+template void directConvRowSum<1>(
+    const conv_param_t<1>& conv_p,
+    const uint8_t* A,
+    int32_t* inSum,
+    int32_t* rowSum);
+
+template void directConvRowSum<2>(
+    const conv_param_t<2>& conv_p,
+    const uint8_t* A,
+    int32_t* inSum,
+    int32_t* rowSum);
+
+template void directConvRowSum<3>(
+    const conv_param_t<3>& conv_p,
+    const uint8_t* A,
+    int32_t* inSum,
+    int32_t* rowSum);
+
+template <
+    int SPATIAL_DIM,
+    QuantizationGranularity Q_GRAN,
+    bool FUSE_RELU,
+    typename BIAS_TYPE>
+void fbgemmDirectConv(
+    const conv_param_t<SPATIAL_DIM>& conv_p,
+    const uint8_t* Aint8,
+    PackedDirectConvMatrix& Bint8_tr,
+    uint8_t* C,
+    int32_t* C_buffer,
+    const ReQuantizeOutput<FUSE_RELU, Q_GRAN, BIAS_TYPE>& outProcess,
+    const BIAS_TYPE* bias,
+    // const int32_t* bias,
+    int thread_id,
+    int num_threads) {
+  // support for single thread now,
+  // will enable multithread later
+  if (thread_id > 0 || thread_id >= num_threads) {
+    return;
+  }
+
+  if (SPATIAL_DIM != 2) {
+    assert(false && "1d/3d direct conv not supported");
+  } else {
+    if (conv_p.transposed) {
+      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
+          jit_micro_kernel_fp_convT fn;
+      DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+      /*
+      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
+          true, conv_p.stride[1]);
+      */
+      fn = codeObj.getOrCreateDirectConvTrans<inst_set_t::avx2>(
+          true, conv_p.stride[1], conv_p.K[1]);
+
+      int32_t* inSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
+          64, conv_p.IN_DIM[0] * conv_p.IN_DIM[1] * sizeof(int32_t)));
+      int32_t* rowSum = static_cast<int32_t*>(fbgemmAlignedAlloc(
+          64, conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * sizeof(int32_t)));
+
+      directConvRowSum(conv_p, Aint8, inSum, rowSum);
+      int kernel_dim = conv_p.K[0] * conv_p.K[1];
+
+      std::memset(
+          C_buffer,
+          0,
+          sizeof(int32_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
+      std::memset(
+          C,
+          0,
+          sizeof(int8_t) * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1] * conv_p.OC);
+      // no-op output process objects
+      for (int i = 0; i < conv_p.OC; i += 8) {
+        for (int j = 0; j < conv_p.IN_DIM[0]; j++) {
+          fn(Aint8 + j * conv_p.IC * conv_p.IN_DIM[1],
+             Bint8_tr.PackedMat() + i * kernel_dim * conv_p.IC,
+             C_buffer + j * conv_p.OUT_DIM[1] * conv_p.OC + i,
+             conv_p.IC,
+             conv_p.OC,
+             (conv_p.OC * conv_p.OUT_DIM[1] - conv_p.OC * conv_p.K[1]) * 4,
+             conv_p.IN_DIM[1]);
+        }
+      }
+
+      int32_t A_zero_point = outProcess.getAZeroPoint();
+      const int32_t* B_zero_point = outProcess.getBZeroPoint();
+      // const float* C_multiplier = outProcess.getCMultiplier();
+      const int32_t* col_offsets = outProcess.getColOffsets();
+
+      /*
+      int groups = 1;
+      if (Q_GRAN == QuantizationGranularity::OUT_CHANNEL) {
+        groups = conv_p.OC;
+      }
+      */
+      requantizationParams_t<BIAS_TYPE> reqObj = {
+          outProcess.getAZeroPoint(),
+          outProcess.getBZeroPoint(),
+          outProcess.getCZeroPoint(),
+          outProcess.getCMultiplier(),
+          rowSum, // rowOffsetBuf,
+          outProcess.getColOffsets(),
+          (outProcess.getBias()),
+          static_cast<std::uint32_t>(conv_p.OC), // outProcess.getNCols(),
+          1, // groups
+          outProcess.getActWScale()};
+
+      // Dispatch HAS_BIAS
+      if (bias == nullptr) {
+        // Dispatch A_SYMMETRIC and B_SYMMETRIC
+        if (A_zero_point == 0 || col_offsets == nullptr) {
+          if (Q_GRAN == QuantizationGranularity::TENSOR &&
+              B_zero_point[0] == 0) {
+            requantizeOutputProcessingAvx2<
+                true,
+                true,
+                QuantizationGranularity::TENSOR,
+                false, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          } else {
+            requantizeOutputProcessingAvx2<
+                true,
+                false,
+                Q_GRAN,
+                false, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          }
+        } else {
+          if (Q_GRAN == QuantizationGranularity::TENSOR &&
+              B_zero_point[0] == 0) {
+            requantizeOutputProcessingAvx2<
+                false,
+                true,
+                QuantizationGranularity::TENSOR,
+                false, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          } else {
+            requantizeOutputProcessingAvx2<
+                false,
+                false,
+                Q_GRAN,
+                false, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          }
+        }
+      } else { // has_bias == true
+
+        // dispatch A_SYMMETRIC and B_SYMMETRIC
+        if (A_zero_point == 0 || col_offsets == nullptr) {
+          if (Q_GRAN == QuantizationGranularity::TENSOR &&
+              B_zero_point[0] == 0) {
+            requantizeOutputProcessingAvx2<
+                true,
+                true,
+                QuantizationGranularity::TENSOR,
+                true, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          } else {
+            requantizeOutputProcessingAvx2<
+                true,
+                false,
+                Q_GRAN,
+                true, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          }
+        } else {
+          if (Q_GRAN == QuantizationGranularity::TENSOR &&
+              B_zero_point[0] == 0) {
+            requantizeOutputProcessingAvx2<
+                false,
+                true,
+                QuantizationGranularity::TENSOR,
+                true, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          } else {
+            requantizeOutputProcessingAvx2<
+                false,
+                false,
+                Q_GRAN,
+                true, // HAS_BIAS,
+                FUSE_RELU,
+                BIAS_TYPE,
+                true>(
+                C,
+                C_buffer,
+                {0, conv_p.OUT_DIM[1] * conv_p.OUT_DIM[0], 0, conv_p.OC},
+                conv_p.OC,
+                conv_p.OC,
+                reqObj);
+          }
+        }
+      }
+      fbgemmAlignedFree(inSum);
+      fbgemmAlignedFree(rowSum);
+    } // transposed conv
+    else { // non-transposed conv
+      assert(false && "non-transposed direct conv not integrated yet.");
+    }
+  } // else SPATIAL_DIM
+}
+
+#define INSTANTIATE_REQUANTIZE_SPATIAL_DIM(                        \
+    SPATIAL_DIM, Q_GRAN, RELU, BIAS_TYPE)                          \
+  template void FBGEMM_API                                         \
+  fbgemmDirectConv(                                                \
+      const conv_param_t<SPATIAL_DIM>& conv_p,                     \
+      const uint8_t* Aint8,                                        \
+      PackedDirectConvMatrix& Bint8_tr,                            \
+      uint8_t* C,                                                  \
+      int32_t* C_buffer,                                           \
+      const ReQuantizeOutput<RELU, Q_GRAN, BIAS_TYPE>& outProcess, \
+      const BIAS_TYPE* bias,                                       \
+      int thread_id,                                               \
+      int num_threads);
+
+#define INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, BIAS_TYPE) \
+  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(1, Q_GRAN, RELU, BIAS_TYPE)  \
+  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(2, Q_GRAN, RELU, BIAS_TYPE)  \
+  INSTANTIATE_REQUANTIZE_SPATIAL_DIM(3, Q_GRAN, RELU, BIAS_TYPE)
+
+#define INSTANTIATE_REQUANTIZE(Q_GRAN, RELU)            \
+  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, float) \
+  INSTANTIATE_REQUANTIZE_BIAS_TYPE(Q_GRAN, RELU, int32_t)
+
+#define INSTANTIATE_Q_GRANS(RELU)                                \
+  INSTANTIATE_REQUANTIZE(QuantizationGranularity::TENSOR, RELU)  \
+  INSTANTIATE_REQUANTIZE(QuantizationGranularity::GROUP, RELU)   \
+  INSTANTIATE_REQUANTIZE(QuantizationGranularity::OUT_CHANNEL, RELU)
+
+INSTANTIATE_Q_GRANS(true)
+INSTANTIATE_Q_GRANS(false)
+
+#undef INSTANTIATE_REQUANTIZE_SPATIAL_DIM
+#undef INSTANTIATE_REQUANTIZE_BIAS_TYPE
+#undef INSTANTIATE_REQUANTIZE
+#undef INSTANTIATE_Q_GRANS
+} // namespace fbgemm
diff --git a/test/I8DirectconvTest.cc b/test/I8DirectconvTest.cc
new file mode 100644
index 0000000000..f76f483da7
--- /dev/null
+++ b/test/I8DirectconvTest.cc
@@ -0,0 +1,843 @@
+/*
+ * Copyright (c) Facebook, Inc. and its affiliates.
+ * All rights reserved.
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <algorithm>
+#include <cmath>
+#include <cstdio>
+#include <cstring>
+#include <numeric>
+#include <random>
+#include <string>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include "bench/AlignedVec.h"
+#include "bench/BenchUtils.h"
+#include "fbgemm/Fbgemm.h"
+#include "fbgemm/FbgemmI8DepthwiseAvx2.h"
+#include "src/DirectConv.h"
+#include "src/OptimizedKernelsAvx2.h"
+#include "src/RefImplementations.h"
+
+using namespace std;
+
+namespace fbgemm {
+
+// From Xray OCR
+// clang-format off
+// conv_param_t<>(N, IC, OC, H, W, G,
+//               /* kern */ {kernel1, kernel2}, /* stride */ {stride1, stride2},
+//               /* padding */ {pad, pad, pad, pad},
+//               /* dilation */ {1, 1}, /* otpt_pad */ {0, 0}, /* trans */ transpose),
+// 2D conv shapes
+vector<conv_param_t<2>> shapes = {
+    // MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w,
+    // pad_h_top, pad_w_left, pad_h_bottom, pad_w_right,
+    // (dilation_h, dilation_w, output_padding_h, output_padding_w, transpose)
+    // 2D convolutions
+    // regular
+
+    // Ferraris Model
+    // Data from -
+    // https://docs.google.com/spreadsheets/d/1VM-nglZl-pSwBdgYm3VbeLRcORc5y_vTRl9anRCUSDQ/edit#gid=1776750723
+    // conv_param_t<>(N, IC, OC, H, W, G,
+    //               /* kern */ {kernel1, kernel2}, /* stride */ {stride1, stride2},
+    //               /* padding */ {pad, pad, pad, pad},
+    //               /* dilation */ {1, 1}, /* otpt_pad */ {0, 0}, /* trans */ transpose),
+
+    conv_param_t<>(1, 128, 128, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false),
+    conv_param_t<>(1, 16, 16, {2, 126}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false),
+    conv_param_t<>(1, 64, 64, {2, 257}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0}, false),
+};
+
+vector<conv_param_t<2>> shapes_trans = {
+    conv_param_t<>(1, 256, 176, {2, 4}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
+    conv_param_t<>(1, 128, 128, {4, 12}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
+    conv_param_t<>(1, 512, 64, {4, 50}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
+};
+
+namespace {
+/*
+class FBGemmDirectConvTest
+    : public testing::TestWithParam<tuple<bool, bool, int>> {};
+*/
+class FBGemmDirectConvTransTest
+    : public testing::TestWithParam<tuple<bool, bool, int>> {};
+
+class FBGemmDirectConvTransFbgemmTest
+    : public testing::TestWithParam<tuple<bool, bool, int>> {};
+
+} // namespace
+
+template <int SPATIAL_DIM>
+void transposeConvWeights_KwIchO8I4(
+    const conv_param_t<SPATIAL_DIM>& conv_p,
+    const std::int8_t* src,
+    std::int8_t* dest) {
+  int G = conv_p.G;
+  int IC_per_G = conv_p.IC / conv_p.G;
+  int OC_per_G = conv_p.OC / conv_p.G;
+
+  int filter_prod = std::accumulate(
+      conv_p.K.begin(),
+      conv_p.K.begin() + SPATIAL_DIM,
+      1,
+      std::multiplies<int>());
+  // Transforms weights from G K/G (T R S C/G) to G (T R S C/G) K/G format.
+  // the transposed weight layout: W[oc/8][h][w][ic/4][8][4]
+  for (int g = 0; g < G; ++g) {
+    for (int k = 0; k < OC_per_G; ++k) {
+      for (int f = 0; f < filter_prod; ++f) {
+        for (int c = 0; c < IC_per_G; ++c) {
+          int ocB = k / 8;
+          int ocb = k % 8;
+          int icB = c / 4;
+          int icb = c % 4;
+          dest
+              [(((ocB * filter_prod + f) * (IC_per_G / 4) + icB) * 8 + ocb) *
+                   4 +
+               icb] =
+                  src[((g * OC_per_G + k) * filter_prod + f) * IC_per_G + c];
+        }
+      }
+    }
+  }
+}
+
+void directConvRowSum(
+    const conv_param_t<2>& conv_p,
+    uint8_t* A,
+    int32_t* inSum,
+    int32_t* rowSum) {
+  int IN0 = conv_p.IN_DIM[0];
+  int IN1 = conv_p.IN_DIM[1];
+  int IC = conv_p.IC;
+  int K0 = conv_p.K[0];
+  int K1 = conv_p.K[1];
+  int OUT0 = conv_p.OUT_DIM[0];
+  int OUT1 = conv_p.OUT_DIM[1];
+  int stride = conv_p.stride[1];
+
+  memset(rowSum, 0, sizeof(int32_t) * OUT0 * OUT1);
+  for (int ih = 0; ih < IN0; ++ih)
+    for (int iw = 0; iw < IN1; ++iw) {
+      inSum[ih * IN1 + iw] = reduceAvx2(A + ih * IN1 * IC + iw * IC, IC);
+    }
+
+  for (int ih = 0; ih < IN0; ++ih)
+    for (int iw = 0; iw < IN1; iw++) {
+      for (int r = 0; r < K0; ++r) {
+        for (int s = 0; s < K1; ++s) {
+          rowSum[(ih + r) * OUT1 + iw * stride + s] += inSum[ih * IN1 + iw];
+        }
+      }
+    }
+  /*
+  compare_buffers(
+      rowSum,
+      rowoffsets,
+      OUT0,
+      OUT1,
+      OUT1,
+      5);
+  */
+}
+
+void col_offsets_with_zero_pt_s8acc32_DirectConvT_ref(
+    const conv_param_t<2>& conv_p,
+    const int8_t* Bint8,
+    const int32_t* B_zero_point,
+    int32_t* col_offsets,
+    int ncols_per_quant_group) {
+  int IC = conv_p.IC;
+  int OC = conv_p.OC;
+  array<int, 2> IN_DIM = conv_p.IN_DIM;
+  array<int, 2> OUT_DIM = conv_p.OUT_DIM;
+  array<int, 2> K = conv_p.K;
+  array<int, 2> stride = conv_p.stride;
+
+  int MDim = conv_p.MB * OUT_DIM[0] * OUT_DIM[1];
+  int NDim = conv_p.OC / conv_p.G;
+  // int KDim = K[0] * K[1] * conv_p.IC;
+
+  std::memset(col_offsets, 0, sizeof(int32_t) * MDim * NDim);
+  vector<int32_t> count(MDim * NDim, 0);
+  for (int oc = 0; oc < OC; oc++) {
+    for (int ih = 0; ih < IN_DIM[0]; ih++) {
+      for (int iw = 0; iw < IN_DIM[1]; iw++) {
+        for (int kh = 0; kh < K[0]; kh++) {
+          for (int kw = 0; kw < K[1]; kw++) {
+            for (int ic = 0; ic < IC; ic++) {
+              int oh = ih * stride[0] + kh;
+              int ow = iw * stride[1] + kw;
+              col_offsets[(oh * OUT_DIM[1] + ow) * OC + oc] += Bint8
+                  [(((((oc / 8) * K[0] + kh) * K[1] + kw) * (IC / 4) + ic / 4) *
+                        8 +
+                    (oc % 8)) *
+                       4 +
+                   (ic % 4)];
+              count[(oh * OUT_DIM[1] + ow) * OC + oc]++;
+            }
+          }
+        }
+      }
+    }
+  }
+
+  for (int oc = 0; oc < OC; oc++) {
+    for (int oh = 0; oh < OUT_DIM[0]; oh++) {
+      for (int ow = 0; ow < OUT_DIM[1]; ow++) {
+        col_offsets[(oh * OUT_DIM[1] + ow) * OC + oc] -=
+            B_zero_point[oc / ncols_per_quant_group] *
+            count[(oh * OUT_DIM[1] + ow) * OC + oc];
+      }
+    }
+  }
+}
+
+void QuantizeDirectConv_ref(
+    const conv_param_t<2>& conv_p,
+    aligned_vector<uint8_t> Aint8,
+    aligned_vector<int8_t> Bint8,
+    aligned_vector<int32_t>& Cint32_ref,
+    aligned_vector<uint8_t>& Cint8_ref,
+    int32_t Aint8_zero_point,
+    aligned_vector<float> C_multiplier,
+    int32_t C_zero_point,
+    aligned_vector<int32_t> Bint8_zero_point) {
+  int im_out_dim = accumulate(
+      conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
+  int kernel_dim =
+      accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
+
+  aligned_vector<int8_t> Bint8_tr(
+      kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+  transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data());
+  conv_ref(
+      conv_p,
+      Aint8.data(),
+      Aint8_zero_point,
+      Bint8_tr.data(),
+      Cint32_ref.data());
+
+  // matrix dimensions after im2col
+  int MDim = conv_p.MB * im_out_dim;
+  int NDim = conv_p.OC / conv_p.G;
+  int KDim = kernel_dim * conv_p.IC;
+  int KDimPerGroup = KDim / conv_p.G;
+
+  int OC_per_G = conv_p.OC / conv_p.G;
+
+  // computing row offset
+  vector<int32_t> row_offsets(MDim);
+  vector<uint8_t> Aint8_im2col(MDim * KDim);
+  im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+  vector<int32_t> row_offsets_sum(MDim, 0);
+  vector<int32_t> in_row_offsets_sum(conv_p.IN_DIM[0] * conv_p.IN_DIM[1], 0);
+
+  // computing column offset
+  vector<int32_t> col_offsets(conv_p.OC);
+  for (int g = 0; g < conv_p.G; ++g) {
+    col_offsets_with_zero_pt_s8acc32_ref(
+        KDimPerGroup,
+        OC_per_G,
+        OC_per_G,
+        Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
+        Bint8_zero_point.data(),
+        col_offsets.data() + g * OC_per_G,
+        conv_p.OC);
+  }
+
+  for (int g = 0; g < conv_p.G; ++g) {
+    row_offsets_u8acc32_ref(
+        MDim,
+        KDimPerGroup,
+        KDim,
+        Aint8_im2col.data() + g * KDimPerGroup,
+        row_offsets.data());
+
+    requantize_u8acc32_ref(
+        MDim,
+        NDim,
+        conv_p.G * NDim,
+        Cint32_ref.data() + g * NDim,
+        Cint8_ref.data() + g * NDim,
+        C_multiplier.data() + g * NDim / conv_p.OC,
+        C_zero_point,
+        Aint8_zero_point,
+        Bint8_zero_point.data() + g * NDim / conv_p.OC,
+        row_offsets.data(),
+        col_offsets.data() + g * NDim,
+        nullptr,
+        conv_p.OC);
+  }
+}
+
+/*
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    FBGemmDirectConvTest,
+    ::testing::Combine(
+        ::testing::Bool(), // a_symmetric
+        ::testing::Bool(), // b_symmetric
+        ::testing::Values(1, 2))); // oc_per_g
+
+TEST_P(FBGemmDirectConvTest, Test2D) {
+  bool a_symmetric, b_symmetric;
+  int oc_per_g;
+  tie(a_symmetric, b_symmetric, oc_per_g) = GetParam();
+
+  for (auto conv_p : shapes) {
+    int im_in_dim = accumulate(
+        conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>());
+    aligned_vector<uint8_t> aBuf(conv_p.MB * im_in_dim * conv_p.IC);
+
+    int kernel_dim =
+        accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
+
+    aligned_vector<int8_t> bBuf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> bBuf_pf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr_vec(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<float> C_multiplier(1);
+    randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+    int32_t C_zero_point = 5;
+
+    int im_out_dim = accumulate(
+        conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
+    // matrix dimensions after im2col
+    int MDim = conv_p.MB * im_out_dim;
+    int NDim = conv_p.OC / conv_p.G;
+    int KDim = kernel_dim * conv_p.IC;
+    int KDimPerGroup = KDim / conv_p.G;
+
+    int OC_per_G = conv_p.OC / conv_p.G;
+    aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
+    aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb(Cint32_ref.size());
+    aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb2(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb2(Cint32_ref.size());
+
+    DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t>::
+        jit_micro_kernel_fp fn;
+    // fn = GemmGetOrCreate(
+    //     true, _MB, _NB, _KB);
+    DirectConvCodeGenBase<uint8_t, int8_t, int32_t, int32_t> codeObj;
+
+    fn = codeObj.getOrCreateDirectConv<inst_set_t::avx2>(
+        true,
+        conv_p.OUT_DIM[1],
+        conv_p.IN_DIM[1] * conv_p.IC,
+        conv_p.stride[1] * conv_p.IC);
+
+    randFill(aBuf, 0, 5);
+    randFill(bBuf, -4, 4);
+    randFill(bBuf_pf, -4, 4);
+
+    int32_t Aint8_zero_point = 4;
+    aligned_vector<int32_t> Bint8_zero_point(1);
+    randFill(Bint8_zero_point, -3, -1);
+
+    aligned_vector<int8_t> bBuf_tr(bBuf.size());
+    transposeConvWeights_KwIchO8I4<2>(conv_p, bBuf.data(), bBuf_tr.data());
+
+    for (int i = 0; i < conv_p.OC; i += 8) {
+      fn(aBuf.data(),
+         bBuf_tr.data() + i * kernel_dim * conv_p.IC,
+         bBuf_pf.data(),
+         Cint32_fb.data() + i,
+         conv_p.IC * conv_p.K[1],
+         conv_p.OC);
+    }
+
+    // reference quantized int8 convolution implementation
+    QuantizeDirectConv_ref(
+        conv_p,
+        aBuf,
+        bBuf,
+        Cint32_ref,
+        Cint8_ref,
+        Aint8_zero_point,
+        C_multiplier,
+        C_zero_point,
+        Bint8_zero_point);
+
+    compare_buffers(
+        Cint32_fb.data(),
+        Cint32_ref.data(),
+        conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1],
+        conv_p.OC,
+        conv_p.OC,
+        5);
+
+    // computing column offset
+    vector<int32_t> col_offsets(conv_p.OC);
+    transposeConvWeights<2>(conv_p, bBuf.data(), Bint8_tr.data());
+    for (int g = 0; g < conv_p.G; ++g) {
+      col_offsets_with_zero_pt_s8acc32_ref(
+          KDimPerGroup,
+          OC_per_G,
+          OC_per_G,
+          Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
+          Bint8_zero_point.data(),
+          col_offsets.data() + g * OC_per_G,
+          conv_p.OC);
+    }
+
+    vector<int32_t> row_offsets(MDim);
+    vector<uint8_t> Aint8_im2col(MDim * KDim);
+    im2col_ref(conv_p, aBuf.data(), Aint8_zero_point, Aint8_im2col.data());
+    for (int g = 0; g < conv_p.G; ++g) {
+      row_offsets_u8acc32_ref(
+          MDim,
+          KDimPerGroup,
+          KDim,
+          Aint8_im2col.data() + g * KDimPerGroup,
+          row_offsets.data());
+
+      requantize_u8acc32_ref(
+          MDim,
+          NDim,
+          conv_p.G * NDim,
+          Cint32_fb.data() + g * NDim,
+          Cint8_fb.data() + g * NDim,
+          C_multiplier.data() + g * NDim / conv_p.OC,
+          C_zero_point,
+          Aint8_zero_point,
+          Bint8_zero_point.data() + g * NDim / conv_p.OC,
+          row_offsets.data(),
+          col_offsets.data() + g * NDim,
+          nullptr,
+          conv_p.OC);
+    }
+
+    // correctness check
+    for (int n = 0; n < conv_p.MB; ++n) {
+      for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+        for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+          for (int k = 0; k < conv_p.OC; ++k) {
+            int H_OUT = conv_p.OUT_DIM[0];
+            int W_OUT = conv_p.OUT_DIM[1];
+            int OC = conv_p.OC;
+            int32_t expected =
+                Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            int32_t actual = Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            EXPECT_EQ(actual, expected)
+                << "Directconv " << conv_p.K[0] << "x" << conv_p.K[1]
+                << " results differ at (" << n << ", " << h << ", " << w
+                << ", " << k << ").";
+          }
+        }
+      }
+    }
+
+  } // for each shape
+}
+*/
+
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    FBGemmDirectConvTransTest,
+    ::testing::Combine(
+        ::testing::Bool(), // a_symmetric
+        ::testing::Bool(), // b_symmetric
+        ::testing::Values(1, 2))); // oc_per_g
+
+TEST_P(FBGemmDirectConvTransTest, Test2D) {
+  bool a_symmetric, b_symmetric;
+  int oc_per_g;
+  tie(a_symmetric, b_symmetric, oc_per_g) = GetParam();
+
+  for (auto conv_p : shapes_trans) {
+    int im_in_dim = accumulate(
+        conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>());
+    aligned_vector<uint8_t> aBuf(conv_p.MB * im_in_dim * conv_p.IC);
+
+    int kernel_dim =
+        accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
+
+    aligned_vector<int8_t> bBuf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> bBuf_pf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr_vec(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<float> C_multiplier(1);
+    randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+    int32_t C_zero_point = 5;
+
+    int im_out_dim = accumulate(
+        conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
+    // matrix dimensions after im2col
+    int MDim = conv_p.MB * im_out_dim;
+    int NDim = conv_p.OC / conv_p.G;
+    int KDim = kernel_dim * conv_p.IC;
+    int KDimPerGroup = KDim / conv_p.G;
+
+    int OC_per_G = conv_p.OC / conv_p.G;
+    aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
+    aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb2(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb2(Cint32_ref.size());
+
+    randFill(aBuf, 0, 5);
+    randFill(bBuf, -4, 4);
+    randFill(bBuf_pf, -4, 4);
+
+    int32_t Aint8_zero_point = 4;
+    aligned_vector<int32_t> Bint8_zero_point(1);
+    randFill(Bint8_zero_point, -3, -1);
+
+    aligned_vector<int8_t>& Bint8 = bBuf;
+    aligned_vector<uint8_t>& Aint8 = aBuf;
+
+    // reference implementation
+    // conv_ref expects weights to be in G (R S C/G) K/G
+    transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data());
+    transposeConvWeights_KwIchO8I4<2>(
+        conv_p, Bint8.data(), Bint8_tr_vec.data());
+
+    conv_ref(
+        // DirectConvTrans_ref(
+        conv_p,
+        Aint8.data(),
+        Aint8_zero_point,
+        Bint8_tr.data(),
+        Cint32_ref.data());
+
+    // computing row offset
+    vector<int32_t> row_offsets(MDim);
+    vector<uint8_t> Aint8_im2col(MDim * KDim);
+    im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+    // computing column offset
+    vector<int32_t> col_offsets(conv_p.OC);
+    for (int g = 0; g < conv_p.G; ++g) {
+      col_offsets_with_zero_pt_s8acc32_ref(
+          KDimPerGroup,
+          OC_per_G,
+          OC_per_G,
+          Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
+          Bint8_zero_point.data(),
+          col_offsets.data() + g * OC_per_G,
+          conv_p.OC);
+    }
+
+    for (int g = 0; g < conv_p.G; ++g) {
+      row_offsets_u8acc32_ref(
+          MDim,
+          KDimPerGroup,
+          KDim,
+          Aint8_im2col.data() + g * KDimPerGroup,
+          row_offsets.data());
+
+      requantize_u8acc32_ref(
+          MDim,
+          NDim,
+          conv_p.G * NDim,
+          Cint32_ref.data() + g * NDim,
+          Cint8_ref.data() + g * NDim,
+          C_multiplier.data() + g * NDim / conv_p.OC,
+          C_zero_point,
+          Aint8_zero_point,
+          Bint8_zero_point.data() + g * NDim / conv_p.OC,
+          row_offsets.data(),
+          col_offsets.data() + g * NDim,
+          nullptr,
+          conv_p.OC);
+    }
+
+    // computing column offset
+    vector<int32_t> col_offsetsT(conv_p.OC * MDim);
+    for (int g = 0; g < conv_p.G; ++g) {
+      col_offsets_with_zero_pt_s8acc32_DirectConvT_ref(
+          conv_p,
+          Bint8_tr_vec.data() + g * KDimPerGroup * OC_per_G,
+          Bint8_zero_point.data(),
+          col_offsetsT.data() + g * OC_per_G,
+          conv_p.OC);
+    }
+
+    string runType;
+
+    PackedDirectConvMatrix packedB(
+        conv_p.IC, conv_p.OC, kernel_dim, Bint8.data());
+
+    DoNothing<> doNothingObj{};
+    ReQuantizeOutput<false> outputProcObj(
+        doNothingObj,
+        C_multiplier.data(),
+        C_zero_point,
+        Aint8_zero_point,
+        Bint8_zero_point.data(),
+        nullptr, // row offsets
+        col_offsetsT.data(),
+        nullptr, // bias
+        conv_p.OC,
+        conv_p.G);
+
+    int32_t* bias_p = nullptr;
+    fbgemmDirectConv(
+        conv_p,
+        Aint8.data(),
+        packedB,
+        Cint8_fb.data(),
+        Cint32_fb.data(),
+        outputProcObj,
+        bias_p, // bias
+        0,
+        1);
+
+    /*
+    compare_buffers(
+        Cint8_ref.data(),
+        Cint8_fb.data(),
+        MDim,
+        NDim * conv_p.G,
+        NDim * conv_p.G,
+        5);
+    */
+
+    // correctness check
+    for (int n = 0; n < conv_p.MB; ++n) {
+      for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+        for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+          for (int k = 0; k < conv_p.OC; ++k) {
+            int H_OUT = conv_p.OUT_DIM[0];
+            int W_OUT = conv_p.OUT_DIM[1];
+            int OC = conv_p.OC;
+            int32_t expected =
+                Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            int32_t actual = Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            EXPECT_EQ(actual, expected)
+                << "DirectconvTrans " << conv_p.K[0] << "x" << conv_p.K[1]
+                << " results differ at (" << n << ", " << h << ", " << w
+                << ", " << k << ").";
+          }
+        }
+      }
+    }
+
+  } // for each shape
+}
+
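Aside: both transposed-conv tests in this file rely on the row-sum logic added in PackWeightsForDirectConv.cc. A scalar sketch of the same computation — row_sum_ref is a hypothetical stand-in that replaces the internal reduceAvx2 call with a plain loop; like the patch it applies stride only along the width dimension and expects rowSum pre-zeroed and sized OUT_DIM[0] * OUT_DIM[1]:

#include <cstdint>

// For transposed conv with zero padding, input pixel (ih, iw) contributes to
// output pixel (ih + r, iw * stride + s) for every kernel tap (r, s), so each
// output's row offset is the sum of the per-pixel input sums that reach it.
void row_sum_ref(
    const std::uint8_t* A, int IN0, int IN1, int IC, int K0, int K1,
    int stride, int OUT1, std::int32_t* rowSum) {
  for (int ih = 0; ih < IN0; ++ih) {
    for (int iw = 0; iw < IN1; ++iw) {
      std::int32_t s = 0; // inSum[ih][iw]; fbgemm computes this with reduceAvx2
      for (int c = 0; c < IC; ++c) {
        s += A[(ih * IN1 + iw) * IC + c];
      }
      for (int r = 0; r < K0; ++r) {
        for (int t = 0; t < K1; ++t) {
          rowSum[(ih + r) * OUT1 + iw * stride + t] += s;
        }
      }
    }
  }
}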
+INSTANTIATE_TEST_CASE_P(
+    InstantiationName,
+    FBGemmDirectConvTransFbgemmTest,
+    ::testing::Combine(
+        ::testing::Bool(), // a_symmetric
+        ::testing::Bool(), // b_symmetric
+        ::testing::Values(1, 2))); // oc_per_g
+
+TEST_P(FBGemmDirectConvTransFbgemmTest, Test2D) {
+  bool a_symmetric, b_symmetric;
+  int oc_per_g;
+  tie(a_symmetric, b_symmetric, oc_per_g) = GetParam();
+
+  for (auto conv_p : shapes_trans) {
+    int im_in_dim = accumulate(
+        conv_p.IN_DIM.begin(), conv_p.IN_DIM.end(), 1, multiplies<int>());
+    aligned_vector<uint8_t> aBuf(conv_p.MB * im_in_dim * conv_p.IC);
+
+    int kernel_dim =
+        accumulate(conv_p.K.begin(), conv_p.K.end(), 1, multiplies<int>());
+
+    aligned_vector<int8_t> bBuf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> bBuf_pf(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<int8_t> Bint8_tr_vec(
+        kernel_dim * conv_p.IC * (conv_p.OC / conv_p.G));
+
+    aligned_vector<float> C_multiplier(1);
+    randFill(C_multiplier, 0.001234f / 2, 0.001234f * 3 / 2);
+    int32_t C_zero_point = 5;
+
+    int im_out_dim = accumulate(
+        conv_p.OUT_DIM.begin(), conv_p.OUT_DIM.end(), 1, multiplies<int>());
+    // matrix dimensions after im2col
+    int MDim = conv_p.MB * im_out_dim;
+    int NDim = conv_p.OC / conv_p.G;
+    int KDim = kernel_dim * conv_p.IC;
+    int KDimPerGroup = KDim / conv_p.G;
+
+    int OC_per_G = conv_p.OC / conv_p.G;
+    aligned_vector<int32_t> Cint32_ref(conv_p.MB * im_out_dim * conv_p.OC);
+    aligned_vector<uint8_t> Cint8_ref(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb(Cint32_ref.size(), 0);
+    aligned_vector<uint8_t> Cint8_fb2(Cint32_ref.size(), 0);
+    aligned_vector<int32_t> Cint32_fb2(Cint32_ref.size());
+
+    randFill(aBuf, 0, 5);
+    randFill(bBuf, -4, 4);
+    randFill(bBuf_pf, -4, 4);
+
+    int32_t Aint8_zero_point = 4;
+    aligned_vector<int32_t> Bint8_zero_point(1);
+    randFill(Bint8_zero_point, -3, -1);
+
+    aligned_vector<int8_t>& Bint8 = bBuf;
+    aligned_vector<uint8_t>& Aint8 = aBuf;
+
+    // reference implementation
+    // conv_ref expects weights to be in G (R S C/G) K/G
+    transposeConvWeights<2>(conv_p, Bint8.data(), Bint8_tr.data());
+
+    conv_ref(
+        // DirectConvTrans_ref(
+        conv_p,
+        Aint8.data(),
+        Aint8_zero_point,
+        Bint8_tr.data(),
+        Cint32_ref.data());
+
+    // computing row offset
+    vector<int32_t> row_offsets(MDim);
+    vector<uint8_t> Aint8_im2col(MDim * KDim);
+    im2col_ref(conv_p, Aint8.data(), Aint8_zero_point, Aint8_im2col.data());
+
+    // computing column offset
+    vector<int32_t> col_offsets(conv_p.OC);
+    for (int g = 0; g < conv_p.G; ++g) {
+      col_offsets_with_zero_pt_s8acc32_ref(
+          KDimPerGroup,
+          OC_per_G,
+          OC_per_G,
+          Bint8_tr.data() + g * KDimPerGroup * OC_per_G,
+          Bint8_zero_point.data(),
+          col_offsets.data() + g * OC_per_G,
+          conv_p.OC);
+    }
+
+    for (int g = 0; g < conv_p.G; ++g) {
+      row_offsets_u8acc32_ref(
+          MDim,
+          KDimPerGroup,
+          KDim,
+          Aint8_im2col.data() + g * KDimPerGroup,
+          row_offsets.data());
+
+      requantize_u8acc32_ref(
+          MDim,
+          NDim,
+          conv_p.G * NDim,
+          Cint32_ref.data() + g * NDim,
+          Cint8_ref.data() + g * NDim,
+          C_multiplier.data() + g * NDim / conv_p.OC,
+          C_zero_point,
+          Aint8_zero_point,
+          Bint8_zero_point.data() + g * NDim / conv_p.OC,
+          row_offsets.data(),
+          col_offsets.data() + g * NDim,
+          nullptr,
+          conv_p.OC);
+    }
+
+    // fbgemm top-level function for direct conv path
+    PackWeightsForConv<2> packedB_2D(conv_p, Bint8.data());
+
+    vector<int32_t> col_offsetsT(conv_p.OC * MDim);
+    packedB_2D.getPackedWForDirectconv()
+        .get()
+        ->col_offsets_with_zero_pt_s8acc32_DirectConvT(
+            conv_p,
+            Bint8_zero_point.data(),
+            col_offsetsT,
+            conv_p.OC);
+
+    DoNothing<> doNothingObj{};
+    ReQuantizeOutput<false> outputProcObj(
+        doNothingObj,
+        C_multiplier.data(),
+        C_zero_point,
+        Aint8_zero_point,
+        Bint8_zero_point.data(),
+        nullptr, // row offsets
+        col_offsetsT.data(),
+        nullptr, // bias
+        conv_p.OC,
+        conv_p.G);
+
+    fbgemmConv(
+        conv_p,
+        Aint8.data(),
+        packedB_2D,
+        Cint8_fb.data(),
+        Cint32_fb.data(),
+        outputProcObj,
+        0,
+        1);
+
+    /*
+    compare_buffers(
+        Cint8_ref.data(),
+        Cint8_fb.data(),
+        MDim,
+        NDim * conv_p.G,
+        NDim * conv_p.G,
+        5);
+    */
+
+    // correctness check
+    for (int n = 0; n < conv_p.MB; ++n) {
+      for (int h = 0; h < conv_p.OUT_DIM[0]; ++h) {
+        for (int w = 0; w < conv_p.OUT_DIM[1]; ++w) {
+          for (int k = 0; k < conv_p.OC; ++k) {
+            int H_OUT = conv_p.OUT_DIM[0];
+            int W_OUT = conv_p.OUT_DIM[1];
+            int OC = conv_p.OC;
+            int32_t expected =
+                Cint8_ref[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            int32_t actual = Cint8_fb[((n * H_OUT + h) * W_OUT + w) * OC + k];
+            EXPECT_EQ(actual, expected)
+                << "DirectconvTrans " << conv_p.K[0] << "x" << conv_p.K[1]
+                << " results differ at (" << n << ", " << h << ", " << w
+                << ", " << k << ").";
+          }
+        }
+      }
+    }
+
+  } // for each shape
+}
+
+} // fbgemm namespace
diff --git a/test/UniConvTest.cc b/test/UniConvTest.cc
index 809346952a..52e3f91ece 100644
--- a/test/UniConvTest.cc
+++ b/test/UniConvTest.cc
@@ -135,6 +135,13 @@ GetShapes_() {
     conv_param_t<>(1, 32, 32, {10, 30}, 8, {3, 5}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {0, 0}, true),
     conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {1, 1}, {0, 0}, true),
     conv_param_t<>(1, 32, 32, {10, 30}, 8, {5, 3}, {1, 1}, {1, 1, 1, 1}, {2, 2}, {0, 0}, true),
+    // directconv
+    conv_param_t<>(1, 256, 176, {2, 4}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
+    conv_param_t<>(1, 128, 128, {4, 12}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
+    conv_param_t<>(1, 512, 64, {4, 50}, 1, {2, 6}, {1, 1}, {0, 0, 0, 0},
+                   {1, 1}, {0, 0}, true),
   };
   return shapes;
 }
@@ -205,6 +212,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_1D.getPackedWForDepthwise(), nullptr)
           << "depthwise packed matrix is null";
+      ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::groupwise: {
@@ -216,6 +225,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_1D.getPackedWForGroupwise(), nullptr)
          << "Groupwise packed matrix is null";
+      ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::pointwise: {
@@ -227,6 +238,21 @@ TEST_P(uniConvTest, packingTest) {
           << "Groupwise packed matrix should be null";
       ASSERT_NE(packedB_1D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix is null";
+      ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
+      break;
+    }
+    case optimized_conv_t::directconv: {
+      ASSERT_EQ(packedB_1D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_1D.getPackedWForGroupwise(), nullptr)
+          << "groupwise packed matrix should be null";
+      ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr)
+          << "pointwise packed matrix should be null";
+      ASSERT_NE(packedB_1D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix is null";
+      ASSERT_EQ(packedB_1D.getPackedWForIm2col(), nullptr)
+          << "im2col packed matrix should be null";
       break;
+    }
     case optimized_conv_t::fastpath1d: {
@@ -239,6 +265,8 @@ TEST_P(uniConvTest, packingTest) {
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_1D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix should be null";
+      ASSERT_EQ(packedB_1D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       ASSERT_NE(packedB_1D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix is null";
       break;
@@ -270,6 +298,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForDepthwise(), nullptr)
           << "depthwise packed matrix is null";
+      ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::groupwise: {
@@ -281,6 +311,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForGroupwise(), nullptr)
           << "Groupwise packed matrix is null";
+      ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::pointwise: {
@@ -292,6 +324,21 @@ TEST_P(uniConvTest, packingTest) {
           << "Groupwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix is null";
+      ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
+      break;
+    }
+    case optimized_conv_t::directconv: {
+      ASSERT_EQ(packedB_2D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_2D.getPackedWForGroupwise(), nullptr)
+          << "groupwise packed matrix should be null";
+      ASSERT_EQ(packedB_2D.getPackedWForPointwise(), nullptr)
+          << "pointwise packed matrix should be null";
+      ASSERT_NE(packedB_2D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix is null";
+      ASSERT_EQ(packedB_2D.getPackedWForIm2col(), nullptr)
+          << "im2col packed matrix should be null";
       break;
     }
     case optimized_conv_t::fastpath1d: {
@@ -306,6 +353,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_2D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix is null";
+      ASSERT_EQ(packedB_2D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
   }
@@ -335,6 +384,8 @@ TEST_P(uniConvTest, packingTest) {
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_3D.getPackedWForDepthwise(), nullptr)
           << "depthwise packed matrix is null";
+      ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::groupwise: {
@@ -346,6 +397,8 @@ TEST_P(uniConvTest, packingTest) {
           << "im2col packed matrix should be null";
       ASSERT_NE(packedB_3D.getPackedWForGroupwise(), nullptr)
           << "Groupwise packed matrix is null";
+      ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
     case optimized_conv_t::pointwise: {
@@ -357,6 +410,21 @@ TEST_P(uniConvTest, packingTest) {
           << "im2col packed matrix should be null";
       ASSERT_NE(packedB_3D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix is null";
+      ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
+      break;
+    }
+    case optimized_conv_t::directconv: {
+      ASSERT_EQ(packedB_3D.getPackedWForDepthwise(), nullptr)
+          << "depthwise packed matrix should be null";
+      ASSERT_EQ(packedB_3D.getPackedWForGroupwise(), nullptr)
+          << "groupwise packed matrix should be null";
+      ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
+          << "pointwise packed matrix should be null";
+      ASSERT_NE(packedB_3D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix is null";
+      ASSERT_EQ(packedB_3D.getPackedWForIm2col(), nullptr)
+          << "im2col packed matrix should be null";
+      break;
+    }
     case optimized_conv_t::fastpath1d: {
@@ -371,6 +439,8 @@ TEST_P(uniConvTest, packingTest) {
           << "groupwise packed matrix should be null";
       ASSERT_EQ(packedB_3D.getPackedWForPointwise(), nullptr)
           << "pointwise packed matrix should be null";
       ASSERT_NE(packedB_3D.getPackedWForIm2col(), nullptr)
           << "im2col packed matrix is null";
+      ASSERT_EQ(packedB_3D.getPackedWForDirectconv(), nullptr)
+          << "directconv packed matrix should be null";
       break;
     }
   }
@@ -533,6 +603,33 @@ TEST(uniConvTest, cornerCases) {
       1);
 }

+template <int SPATIAL_DIM = 2, typename ACC_T = std::int32_t>
+bool takeDirectConvPath(const conv_param_t<SPATIAL_DIM>& conv_p) {
+  // Note: direct convolutions (2D) are optimized for
+  //   filter size: 2 x 1 to 2 x 6, transposed conv,
+  //   in_channel % 8 == 0, out_channel % 8 == 0,
+  //   stride = 1 or 2,
+  //   padding = 0 (non-zero padding will be supported soon)
+  bool ret = std::is_same<ACC_T, std::int32_t>::value && conv_p.transposed &&
+      conv_p.G == 1 && conv_p.IC % 8 == 0 && conv_p.OC % 8 == 0 &&
+      std::all_of(
+          conv_p.stride.begin(),
+          conv_p.stride.end(),
+          [](int i) { return i == 1 || i == 2; }) &&
+      SPATIAL_DIM == 2 && conv_p.K[SPATIAL_DIM - 2] == 2 &&
+      conv_p.K[SPATIAL_DIM - 1] <= 6 &&
+      std::all_of(conv_p.dilation.begin(), conv_p.dilation.end(), [](int i) {
+        return i == 1;
+      });
+  // Check pads: zero padding only
+  for (int i = 0; i < SPATIAL_DIM; ++i) {
+    if (conv_p.pad[i] != 0) {
+      ret = false;
+    }
+  }
+  return ret;
+}
+
 /**
  * @brief Unit test for uint8 activations, int8 weights, and 32-bit
  * accumulation. Output processing: requantization -> nothing
@@ -693,6 +790,17 @@ void runRequantizeTest(
   PackWeightsForConv<2> packedWeights(conv_p, Bint8.data());

+  // DirectConv col_offsets is handled differently
+  if (takeDirectConvPath(conv_p)) {
+    packedWeights.getPackedWForDirectconv()
+        .get()
+        ->col_offsets_with_zero_pt_s8acc32_DirectConvT(
+            conv_p,
+            Bint8_zero_point.data(),
+            col_offsets,
+            ncols_per_quant_group);
+  }
+
   // TODO: Uncomment once we support multiple threads in fbgemmGroupwiseConv
   // #ifdef _OPENMP
   // #pragma omp parallel
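End-to-end, the new path is driven through the public fbgemmConv API. A minimal sketch condensed from FBGemmDirectConvTransFbgemmTest above — values are illustrative and the quantization parameters are placeholders:

#include <cstdint>
#include <vector>
#include "fbgemm/Fbgemm.h"

void run_direct_conv_example(const std::uint8_t* A, const std::int8_t* B) {
  using namespace fbgemm;
  // Shape satisfying takeDirectConvPath: transposed, G == 1, IC/OC % 8 == 0,
  // KH == 2, KW <= 6, stride 1 or 2, zero padding, unit dilation.
  conv_param_t<2> conv_p(
      1, 256, 176, {2, 4}, 1, {2, 6}, {1, 2}, {0, 0, 0, 0}, {1, 1}, {0, 0},
      true);
  PackWeightsForConv<2> packedB(conv_p, B);

  // Direct conv prepares col_offsets lazily, on the first forward pass,
  // because the output shape must be known.
  int MDim = conv_p.MB * conv_p.OUT_DIM[0] * conv_p.OUT_DIM[1];
  std::vector<std::int32_t> col_offsets(conv_p.OC * MDim);
  std::int32_t B_zero_point[1] = {0};
  packedB.getPackedWForDirectconv()
      ->col_offsets_with_zero_pt_s8acc32_DirectConvT(
          conv_p, B_zero_point, col_offsets, conv_p.OC);

  float C_multiplier[1] = {0.001234f};
  DoNothing<> doNothing{};
  ReQuantizeOutput<false> outProcess(
      doNothing, C_multiplier, /*C_zero_point=*/5, /*A_zero_point=*/4,
      B_zero_point, /*row_offsets=*/nullptr, col_offsets.data(),
      /*bias=*/nullptr, conv_p.OC, conv_p.G);

  std::vector<std::uint8_t> C(MDim * conv_p.OC);
  std::vector<std::int32_t> C_buffer(MDim * conv_p.OC);
  fbgemmConv(
      conv_p, A, packedB, C.data(), C_buffer.data(), outProcess,
      /*thread_id=*/0, /*num_threads=*/1);
}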