Better groupwise conv for small number of channels per group (#145)
Summary:
Pull Request resolved: #145

Groupwise convolution now supports C_per_G = 2, 4, 8, or 16, with stride 1 or 2.

Groupwise convolution can now also be parallelized across threads.
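As an illustration, here is a minimal sketch of the new multi-threaded call pattern (mirroring the benchmark change below); the buffers and requantization object (Aint8, row_offset_buf, packedWeights, Cint8, Cint32, reqObj) are assumed to be set up as in bench/GroupwiseConvRequantizeBenchmark.cc:

```
// Sketch only: buffer/requantization setup elided. conv_p uses a newly
// supported shape (C_per_G = 64 / 32 = 2).
conv_param_t<> conv_p(1, 64, 64, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});
#ifdef _OPENMP
#pragma omp parallel
#endif
{
  int num_threads = fbgemm_get_num_threads();
  int tid = fbgemm_get_thread_num();
  // Work is partitioned internally via the trailing (thread_id, num_threads).
  fbgemmGroupwiseConv(
      conv_p,
      Aint8.data(),
      Aint8_zero_point,
      row_offset_buf.data(),
      packedWeights,
      Cint8.data(),
      Cint32.data(),
      reqObj,
      tid,
      num_threads);
}
```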

~~TODO: Add performance numbers.~~


The following results are on T1.

**Existing implementation on ResNeXt shapes**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,   3136,      4,   1152, 0.75
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,        direct,   3136,      4,   1152, 24.01
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.93
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      8,   2304, 34.34
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.34
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 41.46
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.35
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 41.82
```
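
For orientation, GOPS is an effective rate over the whole grouped convolution; a hedged sanity check of the column (assuming 2 ops per MAC, and noting that the reported K = IC * KH * KW already folds in all groups while N = OC / G is per group):

```
// Hedged helper (illustrative name) to reproduce the GOPS column.
#include <cstdint>
inline double effective_gops(int64_t M, int64_t N, int64_t K, double seconds) {
  return 2.0 * M * N * K / seconds / 1e9;
}
// First shape: M = 56 * 56 = 3136, N = 128 / 32 = 4, K = 128 * 3 * 3 = 1152,
// i.e. ~28.9 MOps per image; at 24.01 GOPS that is roughly 1.2 ms.
```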

**This diff**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,   3136,      4,   1152, 0.75
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,        direct,   3136,      4,   1152, 14.82
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.92
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      8,   2304, 28.93
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.32
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 41.80
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.35
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 41.89
```

Note: the first shape is slower than with the existing implementation (14.82 vs. 24.01 GOPS), probably because output processing is invoked more frequently; this needs further analysis.

**Extra shapes supported in this diff**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.72
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,        direct,    784,      8,   2304, 23.50
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    196,     16,   4608, 9.84
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,        direct,    196,     16,   4608, 35.08
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      2,    576, 0.18
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      2,    576, 7.50
```

**This diff with OMP_NUM_THREADS=2**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,   3136,      4,   1152, 0.75
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,        direct,   3136,      4,   1152, 30.57
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.92
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      8,   2304, 52.47
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.27
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 68.01
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.28
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 66.12
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.77
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,        direct,    784,      8,   2304, 47.90
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    196,     16,   4608, 9.59
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,        direct,    196,     16,   4608, 60.01
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      2,    576, 0.18
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      2,    576, 9.03
```

**This diff with OMP_NUM_THREADS=4**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,   3136,      4,   1152, 0.75
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,        direct,   3136,      4,   1152, 48.39
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.92
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      8,   2304, 80.02
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.26
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 91.99
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.24
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 93.38
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.85
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,        direct,    784,      8,   2304, 75.54
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    196,     16,   4608, 9.84
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,        direct,    196,     16,   4608, 86.84
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      2,    576, 0.18
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      2,    576, 9.88
```

**This diff with OMP_NUM_THREADS=8**
```
    MB, IC, OC, IH, IW, G, KH, KW, stride_h, stride_w, pad_h, pad_w, Type, M, N, K,  GOPS
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,   3136,      4,   1152, 0.75
   1, 128, 128, 56, 56, 32, 3, 3, 1, 1, 1, 1,        direct,   3136,      4,   1152, 81.98
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.92
   1, 256, 256, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      8,   2304, 104.51
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.25
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 115.44
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    196,     16,   4608, 10.24
   1, 512, 512, 14, 14, 32, 3, 3, 1, 1, 1, 1,        direct,    196,     16,   4608, 113.22
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    784,      8,   2304, 2.85
   1, 256, 256, 56, 56, 32, 3, 3, 2, 2, 1, 1,        direct,    784,      8,   2304, 97.82
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,   FusedIm2Col,    196,     16,   4608, 9.81
   1, 512, 512, 28, 28, 32, 3, 3, 2, 2, 1, 1,        direct,    196,     16,   4608, 113.20
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,   FusedIm2Col,    784,      2,    576, 0.18
   1, 64, 64, 28, 28, 32, 3, 3, 1, 1, 1, 1,        direct,    784,      2,    576, 9.99
```

Reviewed By: jspark1105

Differential Revision: D17934344

fbshipit-source-id: cc4f8100f6f2c4e7e66c40f332121bfb7355c551
dskhudia authored and facebook-github-bot committed Oct 23, 2019
1 parent 00dfc2c commit cf7a2fb
Showing 10 changed files with 1,433 additions and 2,194 deletions.
129 changes: 71 additions & 58 deletions bench/GroupwiseConvRequantizeBenchmark.cc
@@ -24,51 +24,43 @@ using namespace std;
using namespace fbgemm;

void performance_test() {
// clang-format off
vector<conv_param_t<>> shapes = {
// MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, pad_t, pad_l,
// pad_b, pad_r
// conv_param_t<>(1, 16, 16, {16, 14}, 4, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 256, 256, {56, 56}, 64, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 3, 64, {224, 224}, 1, {7, 7}, {2, 2}, {3, 3, 3, 3}),
// conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {2, 2}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 512, 512, {28, 28}, 32, {3, 3}, {2, 2}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 1024, 1024, {14, 14}, 32, {3, 3}, {2, 2}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),
// BatchSize > 1
// conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1,
// 1}),

conv_param_t<>(1, 256, 256, {28, 24}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 256, 256, {24, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),

conv_param_t<>(1, 512, 512, {14, 12}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 512, 512, {12, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// MB, IC, OC, {IH, IW}, G, {KH, KW}, {stride_h, stride_w}, pad_t, pad_l,
// pad_b, pad_r
// conv_param_t<>(1, 16, 16, {16, 14}, 4, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {48, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 256, 256, {56, 56}, 64, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 3, 64, {224, 224}, 1, {7, 7}, {2, 2}, {3, 3, 3, 3}),
// conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 128, 128, {56, 56}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 256, 256, {56, 56}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}),
// conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 512, 512, {28, 28}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}),
// conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 1024, 1024, {14, 14}, 32, {3, 3}, {2, 2},
// {1, 1, 1, 1}),
// conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
// conv_param_t<>(1, 1024, 1024, {7, 7}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),

// BatchSize > 1
// conv_param_t<>(2, 128, 128, {56, 48}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),

conv_param_t<>(1, 256, 256, {28, 24}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 256, 256, {24, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 256, 256, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),

conv_param_t<>(1, 512, 512, {14, 12}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 512, 512, {12, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(1, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
conv_param_t<>(2, 512, 512, {14, 14}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1}),
};
// clang-format on

bool flush = true;
std::vector<char> llc;
@@ -424,17 +416,39 @@ void performance_test() {
// printMatrix(matrix_op_t::NoTranspose, Aint8_im2col.data(), MDim, KDim,
// KDim, "A_out after im2col unpacked");

fbgemmGroupwiseConv(
conv_p,
Aint8.data(),
Aint8_zero_point,
row_offset_buf_direct.data(),
packedWeights,
Cint8_fb_direct.data(),
Cint32_fb_direct.data(),
reqObj,
0,
1);
#ifdef _OPENMP
#pragma omp parallel
#endif
{
int num_threads = fbgemm_get_num_threads();
int tid = fbgemm_get_thread_num();
fbgemmGroupwiseConv(
conv_p,
Aint8.data(),
Aint8_zero_point,
row_offset_buf_direct.data(),
packedWeights,
Cint8_fb_direct.data(),
Cint32_fb_direct.data(),
reqObj,
tid,
num_threads);
}

//printMatrix(
// matrix_op_t::NoTranspose,
// Cint8_ref.data(),
// MDim,
// NDim * conv_p.G,
// NDim * conv_p.G,
// "reference:");
//printMatrix(
// matrix_op_t::NoTranspose,
// Cint8_fb_direct.data(),
// MDim,
// NDim * conv_p.G,
// NDim * conv_p.G,
// "Opt:");

end = chrono::high_resolution_clock::now();

@@ -510,12 +524,11 @@ void performance_test() {

int main() {
#ifdef _OPENMP
// TODO: enable once fbgemmGroupwiseConv support multi-threading
/*// Use 1 thread unless OMP_NUM_THREADS is explicit set.
// Use 1 thread unless OMP_NUM_THREADS is explicitly set.
const char* val = getenv("OMP_NUM_THREADS");
if (val == nullptr || !*val) {*/
if (val == nullptr || !*val) {
omp_set_num_threads(1);
/*}*/
}
#endif
performance_test();
return 0;
9 changes: 9 additions & 0 deletions include/fbgemm/Fbgemm.h
@@ -511,6 +511,13 @@ class FBGEMM_API PackWeightMatrixForGConv {
const inpType* sdata,
inpType* pdata = nullptr);

/**
* Number of groups we work at a time to fill the full simd width
* e.g., with IC_PER_G = 4 and OC_PER_G = 4, we work on two groups at a time
* to fill the avx2 width of 256 bits.
*/
static int numOfGroupsTogether(const conv_param_t<SPATIAL_DIM>& conv_param);

/**
* @brief Packs a block of source matrix into pmat buffer.
*/
@@ -540,6 +547,8 @@ class FBGEMM_API PackWeightMatrixForGConv {
const T* sdata_;
T* pdata_;
bool bufAllocatedHere_;
// Number of groups we work at a time to fill the full simd width
int GTogether_;

/**
* @brief Internal function performing both pack & unpack
29 changes: 29 additions & 0 deletions include/fbgemm/Utils.h
@@ -5,7 +5,9 @@
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <algorithm>
#include <array>
#include <cmath>
#include <string>
#include <type_traits>
#include "FbgemmBuild.h"
@@ -51,6 +53,33 @@ enum class impl_type_t { ref, opt };
*/
enum class layout_t { KCX, KXC };

/**
* @brief Some commonly used variables for different instruction sets
*/
template <inst_set_t inst_set>
struct simd_info;

template <>
struct simd_info<inst_set_t::avx2> {
static constexpr int WIDTH_BITS = 256;
static constexpr int WIDTH_BYTES = 32;
static constexpr int WIDTH_32BIT_ELEMS = 8;
};

template <>
struct simd_info<inst_set_t::avx512> {
static constexpr int WIDTH_BITS = 512;
static constexpr int WIDTH_BYTES = 64;
static constexpr int WIDTH_32BIT_ELEMS = 16;
};

template <>
struct simd_info<inst_set_t::avx512_vnni> {
static constexpr int WIDTH_BITS = 512;
static constexpr int WIDTH_BYTES = 64;
static constexpr int WIDTH_32BIT_ELEMS = 16;
};

/**
* @brief A function to compare data in two buffers for closeness/equality.
*/
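
As a reading aid, a sketch of how GTogether_ could be derived from simd_info, consistent with the numOfGroupsTogether() comment in Fbgemm.h above (with IC_PER_G = OC_PER_G = 4, one group's 4x4 int8 weight tile is 16 bytes, so two tiles fill a 256-bit AVX2 register). Illustrative only, not necessarily the committed implementation:

```
#include <algorithm>
template <int SPATIAL_DIM>
int groupsTogetherSketch(const conv_param_t<SPATIAL_DIM>& conv_p) {
  int IC_per_G = conv_p.IC / conv_p.G;
  int OC_per_G = conv_p.OC / conv_p.G;
  // One group's weight tile is IC_per_G * OC_per_G int8 values per (R, S).
  return std::max(
      simd_info<inst_set_t::avx2>::WIDTH_BYTES / (IC_per_G * OC_per_G), 1);
}
```
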
128 changes: 128 additions & 0 deletions src/CodeGenHelpers.h
@@ -0,0 +1,128 @@
/*
* Copyright (c) Facebook, Inc. and its affiliates.
* All rights reserved.
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once
#include <asmjit/asmjit.h>

namespace fbgemm {

namespace x86 = asmjit::x86;

/**
* @brief Create instruction sequence to generate 16-bit 1s
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:15] will have 0x0001, dest[16:31]
* will have 0x0001 and so on
*/
template <typename T>
void gen16BitVectorOne(x86::Emitter* a, T dest) {
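// x == x sets every bit (0xFFFF per 16-bit lane); a logical right shift by
// 15 then leaves 0x0001 in each lane.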
a->vpcmpeqw(dest, dest, dest);
a->vpsrlw(dest, dest, 15);
}

/**
* @brief Create instruction sequence to generate 8-bit 1s
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest Once the instruction sequence is executed,
* dest[0:7] will have 0x01, dest[8:15]
* will have 0x01 and so on
*/
template <typename T>
void gen8BitVectorOne(x86::Emitter* a, T dest) {
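// All-ones is -1 in every signed byte; vpabsb maps it to +1 (0x01 per byte).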
a->vpcmpeqw(dest, dest, dest);
a->vpabsb(dest, dest);
}

/**
* @brief Generates instruction sequence to compute s32 += U8 * I8
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param cReg contains result
*
*/
template <typename T>
void genU8I8S32FMA(
x86::Emitter* a,
T aReg,
T bReg,
T cReg,
T oneReg16Bit,
T tmpReg) {
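// vpmaddubsw multiplies unsigned bytes of aReg with signed bytes of bReg and
// sums adjacent products into 16 bits (signed-saturated); vpmaddwd against
// the 16-bit ones vector widens pairs into 32 bits, and vpaddd accumulates
// into cReg.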
a->vpmaddubsw(tmpReg, aReg, bReg);
a->vpmaddwd(tmpReg, oneReg16Bit, tmpReg);
a->vpaddd(cReg, tmpReg, cReg);
}

/**
* @brief Add 4 consecutive numbers of type uint8
* and emit their sum as 32-bit numbers.
* i.e., dest[0:31] contains
* src[0:7] + src[8:15] + src[16:23] + src[24:31]
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest contains result
*
*/
template <typename T>
void genU8Sum4(
x86::Emitter* a,
T src,
T dest,
T oneReg16Bit,
T tmpReg) {
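// Multiply each unsigned byte by (signed) 1 and sum adjacent pairs into
// 16 bits (255 + 255 cannot saturate); then sum adjacent 16-bit pairs into
// 32 bits and accumulate into dest.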
gen8BitVectorOne(a, tmpReg);
a->vpmaddubsw(tmpReg, src, tmpReg);
a->vpmaddwd(tmpReg, tmpReg, oneReg16Bit);
a->vpaddd(dest, tmpReg, dest);
/*a->vxorps(tmpReg, tmpReg, tmpReg);*/
/*a->vmpsadbw(tmpReg, src, tmpReg, static_cast<asmjit::Imm>(0));*/
/*a->vpermilps(tmpReg, tmpReg, static_cast<asmjit::Imm>(4));*/
/*a->vpmovzxwd(tmpReg, tmpReg.half());*/
/*a->vpaddd(dest, tmpReg, dest);*/
}

/**
* @brief Add 8 consecutive numbers of type uint8
* and emit their sum as 16-bit numbers.
* i.e., dest[0:15] contains
* src[0:7] + src[8:15] + src[16:23] + src[24:31] +
* src[32:39] + src[40:47] + src[48:55] + src[56:63]
*
* and
*
* dest[64:79] contains
* src[64:71] + src[72:79] + src[80:87] + src[88:95] +
* src[96:103] + src[104:111] + src[112:119] + src[120:127]
*
* so on
*
* @tparam T Register type of destination, e.g., x86::Ymm or x86::Zmm
*
* @param dest contains result
*
*/
template <typename T>
void genU8Sum8(x86::Emitter* a, T src, T dest, T tmpReg) {
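// vpsadbw against a zeroed register yields the plain sum of each group of
// 8 bytes in the low 16 bits of each 64-bit lane; vpaddd accumulates into dest.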
a->vxorps(tmpReg, tmpReg, tmpReg);
a->vpsadbw(tmpReg, src, tmpReg);
a->vpaddd(dest, tmpReg, dest);
}

/**
* @brief Broadcast lower 8-bits of src to destination vector
* register.
*/
template <typename T>
void broadcast8Bit(x86::Emitter* a, x86::Gp src, T dest) {
// Move the GP register into the vector's lower half, then broadcast its
// lowest byte to every byte of dest.
a->movq(dest.half(), src);
a->vpbroadcastb(dest, dest.half());
}

} // namespace fbgemm
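
As a reading aid, a scalar model of what genU8I8S32FMA computes in each 32-bit lane (the function name below is illustrative; the int16 saturation mirrors vpmaddubsw semantics):

```
#include <algorithm>
#include <cstdint>
// One 32-bit lane: c += a[0]*b[0] + ... + a[3]*b[3], with each adjacent pair
// first saturated to int16, exactly as vpmaddubsw does.
inline int32_t u8i8s32FmaLane(const uint8_t a[4], const int8_t b[4], int32_t c) {
  auto sat16 = [](int32_t v) {
    return static_cast<int16_t>(std::min(32767, std::max(-32768, v)));
  };
  int16_t lo = sat16(a[0] * b[0] + a[1] * b[1]); // vpmaddubsw, first pair
  int16_t hi = sat16(a[2] * b[2] + a[3] * b[3]); // vpmaddubsw, second pair
  return c + lo + hi; // vpmaddwd with the ones vector, then vpaddd
}
```
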
15 changes: 11 additions & 4 deletions src/Fbgemm.cc
@@ -219,13 +219,20 @@ FBGEMM_API bool fbgemmOptimizedGConv(const conv_param_t<SPATIAL_DIM>& conv_p) {
int C_per_G = conv_p.IC / conv_p.G;
int K_per_G = conv_p.OC / conv_p.G;

int G_together = PackWeightMatrixForGConv<int8_t, int32_t, SPATIAL_DIM>::
numOfGroupsTogether(conv_p);

return (SPATIAL_DIM == 2) && (C_per_G == K_per_G) &&
(C_per_G == 4 || C_per_G == 8 || C_per_G == 16) && (conv_p.G % 8 == 0) &&
(conv_p.K[0] == conv_p.K[1]) && (conv_p.K[0] == 3) &&
(conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
(C_per_G == 2 || C_per_G == 4 || C_per_G == 8 || C_per_G == 16) &&
(conv_p.G >= G_together) && (conv_p.K[0] == conv_p.K[1]) &&
(conv_p.K[0] == 3) && (conv_p.pad[0] == 1) && (conv_p.pad[1] == 1) &&
((conv_p.IN_DIM[0] % 2 == 0) == (conv_p.IN_DIM[1] % 2 == 0)) &&
(conv_p.IN_DIM[0] >= conv_p.K[0]) && (conv_p.OUT_DIM[0] >= conv_p.K[0]) &&
(conv_p.IN_DIM[1] >= conv_p.K[1]) && (conv_p.OUT_DIM[1] >= conv_p.K[1]) &&
(conv_p.pad[0] == conv_p.pad[2]) && (conv_p.pad[1] == conv_p.pad[3]) &&
(conv_p.dilation[0] == 1) && (conv_p.dilation[0] == conv_p.dilation[1]) &&
(conv_p.stride[0] == 1) && (conv_p.stride[0] == conv_p.stride[1]);
(conv_p.stride[0] == 1 || conv_p.stride[0] == 2) &&
(conv_p.stride[0] == conv_p.stride[1]);
}

template bool fbgemmOptimizedGConv(const conv_param_t<2>& conv_p);
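
For example, the relaxed predicate now admits both classes of newly benchmarked shapes (a sketch; conv_param_t arguments as in the benchmark above):

```
#include <cassert>
conv_param_t<> p1(1, 64, 64, {28, 28}, 32, {3, 3}, {1, 1}, {1, 1, 1, 1});   // C_per_G = 2
conv_param_t<> p2(1, 256, 256, {56, 56}, 32, {3, 3}, {2, 2}, {1, 1, 1, 1}); // stride 2
assert(fbgemmOptimizedGConv(p1)); // previously rejected: C_per_G == 2
assert(fbgemmOptimizedGConv(p2)); // previously rejected: stride == 2
```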