sparse convolution output processing (pytorch#27)
Summary:
Pull Request resolved: pytorch#27

DoSpmdmOnInpBuffer can't be used together with PackAWithIm2Col because DoSpmdmOnInpBuffer expects an im2col'ed A matrix, which PackAWithIm2Col never materializes as a full buffer. This diff implements DoSConvOnInpBuffer, which performs sparse convolution directly on the A input without im2col. The performance is not yet well optimized, and we need to see whether this implementation is good enough to get good resnet50 performance.
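
For orientation, a minimal sketch (not part of this diff) of the index mapping that im2col flattens away; the (kh * KW + kw) * IC + ic layout is an assumption inferred from the rowidx_ comment in FbgemmI8Spmdm.h below:

// Assumed im2col flattening: row r of the im2col'ed A (and the row index of
// a nonzero of B in SpMDM's rowidx_) encodes the kernel offset (kh, kw) and
// the input channel ic in a single integer.
inline int im2colRow(int kh, int kw, int ic, int KW, int IC) {
  return (kh * KW + kw) * IC + ic;
}
// SpMDM therefore needs the im2col'ed A materialized, while SparseConv keeps
// (kh_, kw_, ic_) per nonzero and addresses the original NHWC tensor as
// A[((n * IN_H + ih) * IN_W + iw) * IC + ic] (see SparseConv in this diff).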

Reviewed By: dskhudia

Differential Revision: D13192336

fbshipit-source-id: 2076555ba9749e111afbaec408a2bfa0f55bd5bc
jspark1105 authored and facebook-github-bot committed Nov 29, 2018
1 parent 90535d3 commit 027de07
Showing 7 changed files with 429 additions and 12 deletions.
47 changes: 46 additions & 1 deletion include/fbgemm/Fbgemm.h
@@ -808,7 +808,7 @@ class ReluOutput {
};

/**
* @brief Perform Sparse-Matrix * Dense-Matrix as a part of the output
* @brief Perform Dense-Matrix * Sparse-Matrix as a part of the output
* processing pipeline.
*
* SPMDM (SParse Matrix times Dense Matrix) is applied in place on the 32-bit input buffer
@@ -847,6 +847,51 @@ class DoSpmdmOnInpBuffer {
const int groups_;
};

/**
* @brief Perform sparse convolution directly on the input activations
* (without im2col) as a part of the output processing pipeline.
*
* The sparse convolution accumulates in place into the 32-bit input buffer
* (inp). After modifying the input buffer, pass it to the next op.
* When groups > 1, each group is a numRows() x (numCols()/groups) matrix.
*/
template <
typename outT = std::int32_t,
typename inT = std::int32_t,
typename nextOPType = DoNothing<inT, inT>>
class DoSConvOnInpBuffer {
public:
using outType = outT;
using inpType = inT;
DoSConvOnInpBuffer(
nextOPType& nextop,
const std::uint8_t* A,
const conv_param_t<>& conv_p,
std::int32_t A_zero_point,
const CompressedSparseColumn& B_csc,
int groups = 1)
: nextop_(nextop),
A_(A),
conv_p_(conv_p),
A_zero_point_(A_zero_point),
B_csc_(B_csc) {}

template <inst_set_t instSet>
inline int f(
outT* out,
inT* inp,
const block_type_t& block,
int ld_out,
int ld_in) const;

private:
nextOPType& nextop_;
const std::uint8_t* A_;
const conv_param_t<>& conv_p_;
const std::int32_t A_zero_point_;
const CompressedSparseColumn& B_csc_;
};

enum class QuantizationGranularity {
TENSOR,
GROUP,
35 changes: 33 additions & 2 deletions include/fbgemm/FbgemmI8Spmdm.h
@@ -8,6 +8,7 @@

#include <cstdint>
#include <vector>
#include "ConvUtils.h"
#include "Utils.h"

// #define FBGEMM_MEASURE_TIME_BREAKDOWN
@@ -21,6 +22,7 @@ extern double spmdm_transpose_32xN_time;
extern double spmdm_compute_time;
extern double spmdm_transpose_Nx32_time;
extern double spmdm_run_time;
extern double sconv_run_time;
#endif

namespace fbgemm {
@@ -46,6 +48,19 @@ class CompressedSparseColumn {
std::vector<std::int8_t>& Values() {
return values_;
}
std::vector<std::int16_t>& KHs() {
return kh_;
}
std::vector<std::int16_t>& KWs() {
return kw_;
}
/**
* ICs include the group: i.e., for the i-th input channel within group g,
* ICs contain g * (input channels per group) + i (e.g., with 64 channels
* per group, channel 3 of group 1 is stored as 67)
*/
std::vector<std::int16_t>& ICs() {
return ic_;
}

std::size_t NumOfRows() const {
return num_rows_;
@@ -83,12 +98,28 @@ class CompressedSparseColumn {
std::int32_t* C,
int ldc) const;

void SparseConv(
const conv_param_t<>& conv_p,
const block_type_t& block,
const std::uint8_t* A,
std::int32_t A_zero_point,
bool accumulation,
std::int32_t* C,
int ldc) const;

private:
const std::size_t num_rows_;
std::vector<std::int32_t> colptr_;
std::vector<std::int16_t> rowidx_;
std::vector<std::int32_t> colptr_; // corresponds to out channels
std::vector<std::int8_t> values_;

// For SpMDM
std::vector<std::int16_t> rowidx_; // (kh, kw, ic) flattened as in im2col

// For direct sparse convolution
std::vector<std::int16_t> kh_;
std::vector<std::int16_t> kw_;
std::vector<std::int16_t> ic_; // in channels

// Cache IsHyperSparse to minimize its overhead.
mutable bool hyper_sparse_;

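For context, a minimal standalone sketch of how a dense KH x KW x IC x OC filter could be converted into this extended CSC form; the struct, fill routine, and dense weight layout are illustrative assumptions — the diff itself only adds the accessors and member arrays above:

#include <cstdint>
#include <vector>

// Hypothetical CSC builder: one column per output channel, one entry per
// nonzero filter tap, recording its (kh, kw, ic) coordinates directly
// instead of the single im2col-flattened row index used by SpMDM.
struct SparseFilter {
  std::vector<std::int32_t> colptr; // size OC + 1
  std::vector<std::int8_t> values;
  std::vector<std::int16_t> kh, kw, ic;
};

SparseFilter toCSC(const std::int8_t* W, int KH, int KW, int IC, int OC) {
  SparseFilter f;
  f.colptr.push_back(0);
  for (int oc = 0; oc < OC; ++oc) {
    for (int h = 0; h < KH; ++h) {
      for (int w = 0; w < KW; ++w) {
        for (int c = 0; c < IC; ++c) {
          // Assumed dense layout: W[((h * KW + w) * IC + c) * OC + oc].
          std::int8_t v = W[((h * KW + w) * IC + c) * OC + oc];
          if (v != 0) {
            f.values.push_back(v);
            f.kh.push_back(static_cast<std::int16_t>(h));
            f.kw.push_back(static_cast<std::int16_t>(w));
            f.ic.push_back(static_cast<std::int16_t>(c));
          }
        }
      }
    }
    f.colptr.push_back(static_cast<std::int32_t>(f.values.size()));
  }
  return f;
}
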
12 changes: 12 additions & 0 deletions include/fbgemm/OutputProcessing-inl.h
@@ -44,6 +44,18 @@ inline int DoSpmdmOnInpBuffer<outT, inT, nextOPType>::f(
return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
}

template <typename outT, typename inT, typename nextOPType>
template <inst_set_t instSet>
inline int DoSConvOnInpBuffer<outT, inT, nextOPType>::f(
outT* out,
inT* inp,
const block_type_t& block,
int ld_out,
int ld_in) const {
B_csc_.SparseConv(conv_p_, block, A_, A_zero_point_, true, inp, ld_in);
return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
}

template <
bool FUSE_RELU,
QuantizationGranularity Q_GRAN,
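The definition above follows the generic output-pipeline contract: each stage's f() may transform the int32 tile in place (here, accumulating the sparse convolution into inp) and then must forward the tile to the next stage. A hedged, minimal illustration of a stage obeying the same contract — the class name is hypothetical; only the f() signature and the nextop_ chaining are taken from this diff:

#include "fbgemm/Fbgemm.h"

namespace fbgemm {

// Hypothetical pass-through stage with the same interface as
// DoSpmdmOnInpBuffer / DoSConvOnInpBuffer.
template <
    typename outT = std::int32_t,
    typename inT = std::int32_t,
    typename nextOPType = DoNothing<inT, inT>>
class DoIdentityOnInpBuffer {
 public:
  using outType = outT;
  using inpType = inT;
  explicit DoIdentityOnInpBuffer(nextOPType& nextop) : nextop_(nextop) {}

  template <inst_set_t instSet>
  inline int f(
      outT* out,
      inT* inp,
      const block_type_t& block,
      int ld_out,
      int ld_in) const {
    // No transformation of inp; just pass the tile down the pipeline.
    return nextop_.template f<instSet>(out, inp, block, ld_out, ld_in);
  }

 private:
  nextOPType& nextop_;
};

} // namespace fbgemm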
18 changes: 18 additions & 0 deletions src/ExecuteKernelU8S8.cc
@@ -381,6 +381,24 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE

#define INSTANTIATE_BASE(RELU, Q_GRAN) \
template class ExecuteKernel< \
PackAWithIm2Col<uint8_t, int16_t>, \
PackBMatrix<int8_t, int16_t>, \
uint8_t, \
DoSConvOnInpBuffer<uint8_t, int32_t, ReQuantizeOutput<RELU, Q_GRAN>>>;

#define INSTANTIATE_Q_GRANS(RELU) \
INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);

INSTANTIATE_Q_GRANS(false);
INSTANTIATE_Q_GRANS(true);

#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE
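
For reference, each INSTANTIATE_Q_GRANS above expands mechanically into three explicit instantiations; for example, the TENSOR case of INSTANTIATE_Q_GRANS(false) reads:

template class ExecuteKernel<
    PackAWithIm2Col<uint8_t, int16_t>,
    PackBMatrix<int8_t, int16_t>,
    uint8_t,
    DoSConvOnInpBuffer<
        uint8_t,
        int32_t,
        ReQuantizeOutput<false, QuantizationGranularity::TENSOR>>>;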

template class ExecuteKernel<
PackAWithRowOffset<uint8_t, int16_t>,
PackBMatrix<int8_t, int16_t>,
25 changes: 25 additions & 0 deletions src/Fbgemm.cc
@@ -376,6 +376,31 @@ INSTANTIATE_Q_GRANS(true);
#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE

#define INSTANTIATE_BASE(RELU, Q_GRAN) \
template void fbgemmPacked( \
PackMatrix<PackAWithIm2Col<uint8_t, int16_t>, uint8_t, int16_t>& packA, \
PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB, \
uint8_t* C, \
int32_t* C_buffer, \
uint32_t ldc, \
const DoSConvOnInpBuffer< \
uint8_t, \
int32_t, \
ReQuantizeOutput<RELU, Q_GRAN>>& outProcess, \
int thread_id, \
int num_threads);

#define INSTANTIATE_Q_GRANS(RELU) \
INSTANTIATE_BASE(RELU, QuantizationGranularity::TENSOR); \
INSTANTIATE_BASE(RELU, QuantizationGranularity::GROUP); \
INSTANTIATE_BASE(RELU, QuantizationGranularity::OUT_CHANNEL);

INSTANTIATE_Q_GRANS(false);
INSTANTIATE_Q_GRANS(true);

#undef INSTANTIATE_Q_GRANS
#undef INSTANTIATE_BASE

template void fbgemmPacked(
PackMatrix<PackAWithRowOffset<uint8_t, int16_t>, uint8_t, int16_t>& packA,
PackMatrix<PackBMatrix<int8_t, int16_t>, int8_t, int16_t>& packB,
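Putting the pieces together, a hedged sketch of how a caller might wire DoSConvOnInpBuffer into fbgemmPacked. Construction of packA, packB, and the requantization stage is elided because their constructors are not part of this diff; only the types and the signatures instantiated above are assumed:

#include <cstdint>
#include "fbgemm/Fbgemm.h"

using namespace fbgemm;

using ReQuant = ReQuantizeOutput<false, QuantizationGranularity::TENSOR>;

void runSparseConvPipeline(
    PackAWithIm2Col<uint8_t, int16_t>& packA, // dense path, im2col on the fly
    PackBMatrix<int8_t, int16_t>& packB,      // packed dense part of B
    const uint8_t* A,                         // original NHWC activations
    const conv_param_t<>& conv_p,
    int32_t A_zero_point,
    const CompressedSparseColumn& B_csc,      // sparse part of the filter
    ReQuant& reqOp,                           // final requantization stage
    uint8_t* C,
    int32_t* C_buffer,
    uint32_t ldc) {
  // The sparse convolution accumulates into the int32 tile, then reqOp
  // requantizes it to uint8 as the last pipeline stage.
  DoSConvOnInpBuffer<uint8_t, int32_t, ReQuant> sconvOp(
      reqOp, A, conv_p, A_zero_point, B_csc);
  fbgemmPacked(
      packA, packB, C, C_buffer, ldc, sconvOp,
      /*thread_id=*/0, /*num_threads=*/1);
}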
72 changes: 69 additions & 3 deletions src/FbgemmI8Spmdm.cc
@@ -21,6 +21,7 @@ double spmdm_transpose_32xN_time = 0.0;
double spmdm_compute_time = 0.0;
double spmdm_transpose_Nx32_time = 0.0;
double spmdm_run_time = 0.0;
double sconv_run_time = 0.0;
#endif

using namespace std;
@@ -222,8 +223,8 @@ void CompressedSparseColumn::SpMDM(
t_very_start = std::chrono::high_resolution_clock::now();
#endif

uint8_t A_buffer[K * 32] __attribute__((aligned(64)));
int32_t C_buffer[N * 32] __attribute__((aligned(64)));
alignas(64) uint8_t A_buffer[K * 32];
alignas(64) int32_t C_buffer[N * 32];

// If we compute C = C + A * B, where B is a sparse matrix in CSC format, for
// each non-zero in B, we'd need to access the corresponding column in A.
@@ -269,7 +270,7 @@
for (int i1 = block.row_start; i1 < i_end; i1 += 32) {
// Transpose 32 x K submatrix of A
if (i_end - i1 < 32) {
uint8_t A_temp_buffer[K * 32] __attribute__((aligned(64)));
alignas(64) uint8_t A_temp_buffer[K * 32];
for (int i2 = 0; i2 < (i_end - i1) / 8 * 8; i2 += 8) {
transpose_8rows(K, A + (i1 + i2) * lda, lda, A_buffer + i2, 32);
}
@@ -505,4 +506,69 @@ void CompressedSparseColumn::SpMDM(
#endif
}

void CompressedSparseColumn::SparseConv(
const conv_param_t<>& conv_p,
const block_type_t& block,
const uint8_t* A,
int32_t A_zero_point,
bool accumulation,
int32_t* C,
int ldc) const {
int K = NumOfRows();
int N = block.col_size;

if (K == 0 || N == 0) {
return;
}

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
std::chrono::time_point<std::chrono::high_resolution_clock> t_start, t_end;
double dt;
t_start = std::chrono::high_resolution_clock::now();
#endif

// TODO: if not hyper sparse, transpose a block of A matrix as in SpMDM.
if (!accumulation) {
for (int i = block.row_start; i < block.row_start + block.row_size; ++i) {
for (int j = block.col_start; j < block.col_start + block.col_size;
++j) {
C[(i - block.row_start) * ldc + j - block.col_start] = 0;
}
}
}
for (int j = block.col_start; j < block.col_start + block.col_size; ++j) {
for (int k = colptr_[j]; k < colptr_[j + 1]; ++k) {
int v = values_[k];
for (int i = block.row_start; i < block.row_start + block.row_size;
++i) {
int ow = i % conv_p.OUT_DIM[1];
int oh = i / conv_p.OUT_DIM[1] % conv_p.OUT_DIM[0];
int n = i / conv_p.OUT_DIM[1] / conv_p.OUT_DIM[0];
assert(n < conv_p.MB);
int iw = -conv_p.pad[1] + ow * conv_p.stride[1] + kw_[k];
int ih = -conv_p.pad[0] + oh * conv_p.stride[0] + kh_[k];

if (ih >= 0 && ih < conv_p.IN_DIM[0] && iw >= 0 &&
iw < conv_p.IN_DIM[1]) {
C[(i - block.row_start) * ldc + j - block.col_start] +=
A[((n * conv_p.IN_DIM[0] + ih) * conv_p.IN_DIM[1] + iw) *
conv_p.IC +
ic_[k]] *
v;
} else {
C[(i - block.row_start) * ldc + j - block.col_start] +=
A_zero_point * v;
}
}
}
} // for each column of B

#ifdef FBGEMM_MEASURE_TIME_BREAKDOWN
t_end = std::chrono::high_resolution_clock::now();
dt = std::chrono::duration_cast<std::chrono::nanoseconds>(t_end - t_start)
.count();
sconv_run_time += (dt);
#endif
}

} // namespace fbgemm
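
To make the index arithmetic in SparseConv concrete, a small standalone example (parameter values are hypothetical, ResNet-50-like) that reproduces the output-to-input coordinate mapping and the padding case:

#include <cstdio>
#include <initializer_list>

int main() {
  // Hypothetical layer: 3x3 filter, stride 1, pad 1, 56x56 in/out.
  const int OUT_H = 56, OUT_W = 56, IN_H = 56, IN_W = 56;
  const int stride = 1, pad = 1;
  // One nonzero of B with kernel offset (kh_, kw_) = (0, 0); two flattened
  // output indices i, with layout ((n * OUT_H) + oh) * OUT_W + ow.
  const int kh = 0, kw = 0;
  for (int i : {0, 1234}) {
    int ow = i % OUT_W;
    int oh = i / OUT_W % OUT_H;
    int n = i / OUT_W / OUT_H;
    int iw = -pad + ow * stride + kw;
    int ih = -pad + oh * stride + kh;
    bool inside = ih >= 0 && ih < IN_H && iw >= 0 && iw < IN_W;
    // Out-of-bounds taps land in the padding region: SparseConv adds
    // A_zero_point * value there instead of reading A, matching im2col
    // with padding filled by the zero point.
    printf("i=%d -> n=%d oh=%d ow=%d -> ih=%d iw=%d (%s)\n",
           i, n, oh, ow, ih, iw, inside ? "reads A" : "zero-point padding");
  }
  return 0;
}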