Matrix transpose kernels for int8/uint8 datatypes (pytorch#431)
Summary:
Pull Request resolved: pytorch#431

Vectorized implementation of transposing int8/uint8 matrices.

Reviewed By: jiecaoyu

Differential Revision: D23917645

fbshipit-source-id: 72ba7f7b5716c8f514c44555d9941920e8a7f0fd
dskhudia authored and facebook-github-bot committed Oct 2, 2020
1 parent 1d71039 commit fe91640
Showing 7 changed files with 614 additions and 52 deletions.
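For context, a minimal usage sketch of the templated transpose_simd API this commit introduces (not part of the diff; assumes FBGEMM is built and fbgemm/Utils.h is on the include path):

// Transpose a 3x5 uint8_t matrix with the templated API declared in
// include/fbgemm/Utils.h by this commit.
#include <cassert>
#include <cstdint>
#include <vector>

#include "fbgemm/Utils.h"

int main() {
  constexpr int M = 3, N = 5;
  std::vector<std::uint8_t> src(M * N), dst(N * M);
  for (int i = 0; i < M * N; ++i) {
    src[i] = static_cast<std::uint8_t>(i);
  }

  // dst becomes the N x M transpose of the M x N matrix in src.
  fbgemm::transpose_simd<std::uint8_t>(
      M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);

  // Spot-check against the definition dst[j][i] == src[i][j].
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      assert(dst[j * M + i] == src[i * N + j]);
    }
  }
  return 0;
}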
23 changes: 16 additions & 7 deletions bench/TransposeBenchmark.cc
@@ -18,22 +18,30 @@
 using namespace std;
 using namespace fbgemm;
 
+template <typename T>
 void performance_test() {
   constexpr int NWARMUP = 4;
   constexpr int NITER = 256;
 
-  normal_distribution<float> dist;
+  uniform_int_distribution<int> dist(0, 10);
   default_random_engine engine;
 
-  cout << setw(4) << "M" << setw(4) << "N"
+  string runType;
+  if (is_same<T, float>::value) {
+    runType = "float";
+  } else {
+    runType = "i8";
+  }
+
+  cout << setw(8) << "dtype" << setw(4) << "M" << setw(4) << "N"
        << " B_elements_per_sec" << endl;
 
   int dims[] = {1, 2, 3, 4, 5, 6, 8, 9, 10, 15, 16,
                 17, 32, 33, 63, 64, 65, 127, 128, 129, 255, 256};
   for (int M : dims) {
     for (int N : dims) {
-      vector<float> a(M * N);
-      vector<float> b(N * M), b_ref(N * M);
+      vector<T> a(M * N);
+      vector<T> b(N * M), b_ref(N * M);
 
       generate(a.begin(), a.end(), [&dist, &engine] { return dist(engine); });
       transpose_ref(M, N, a.data(), N, b_ref.data(), M);
@@ -44,15 +44,16 @@ void performance_test() {
           NITER);
       duration *= 1e9; // convert to ns
 
-      cout << setw(4) << M << setw(4) << N << setw(10) << setprecision(3)
-           << (M * N) / duration << endl;
+      cout << setw(8) << runType << setw(4) << M << setw(4) << N << setw(10)
+           << setprecision(3) << (M * N) / duration << endl;
 
       compare_buffers(b_ref.data(), b.data(), M, N, N, 5);
     } // N
   } // M
 } // performance_test
 
 int main() {
-  performance_test();
+  performance_test<float>();
+  performance_test<uint8_t>();
   return 0;
 }
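The throughput column reduces to a simple calculation; the helper below is a sketch of that reading (not FBGEMM code), assuming, as the `// convert to ns` comment implies, that the measured duration arrives in seconds:

// Hedged reading of the "B_elements_per_sec" column: duration (seconds) is
// scaled by 1e9 to nanoseconds, so (M * N) / duration_ns is elements per
// nanosecond, i.e. billions of elements per second.
double b_elements_per_sec(int M, int N, double duration_sec) {
  double duration_ns = duration_sec * 1e9;
  return (M * N) / duration_ns; // e.g. 256 * 256 elements in 65536 ns -> 1.0
}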
10 changes: 3 additions & 7 deletions include/fbgemm/Utils.h
@@ -150,13 +150,9 @@ void printMatrix(
  * @param M the number of rows of input matrix
  * @param N the number of columns of input matrix
  */
-FBGEMM_API void transpose_simd(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst);
+template <typename T>
+FBGEMM_API void
+transpose_simd(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
 
 /**
  * @brief Explicitly set instruction set to be used
54 changes: 41 additions & 13 deletions src/TransposeUtils.cc
@@ -11,42 +11,70 @@
 
 namespace fbgemm {
 
-void transpose_ref(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst) {
+template <typename T>
+void transpose_ref(int M, int N, const T* src, int ld_src, T* dst, int ld_dst) {
   for (int j = 0; j < N; j++) {
     for (int i = 0; i < M; i++) {
       dst[i + j * ld_dst] = src[i * ld_src + j];
     }
   } // for each output row
 }
 
+template <typename T>
 void transpose_simd(
     int M,
     int N,
-    const float* src,
+    const T* src,
     int ld_src,
-    float* dst,
+    T* dst,
     int ld_dst) {
   if ((M == 1 && ld_dst == 1) || (N == 1 && ld_src == 1)) {
     if (dst != src) {
-      memcpy(dst, src, M * N * sizeof(float));
+      memcpy(dst, src, M * N * sizeof(T));
     }
     return;
   }
   static const auto iset = fbgemmInstructionSet();
   // Run time CPU detection
   if (isZmm(iset)) {
-    internal::transpose_avx512(M, N, src, ld_src, dst, ld_dst);
+    internal::transpose_avx512<T>(M, N, src, ld_src, dst, ld_dst);
   } else if (isYmm(iset)) {
-    internal::transpose_avx2(M, N, src, ld_src, dst, ld_dst);
+    internal::transpose_avx2<T>(M, N, src, ld_src, dst, ld_dst);
   } else {
-    transpose_ref(M, N, src, ld_src, dst, ld_dst);
+    transpose_ref<T>(M, N, src, ld_src, dst, ld_dst);
   }
 }
 
+template void transpose_ref<float>(
+    int M,
+    int N,
+    const float* src,
+    int ld_src,
+    float* dst,
+    int ld_dst);
+
+template void transpose_ref<uint8_t>(
+    int M,
+    int N,
+    const uint8_t* src,
+    int ld_src,
+    uint8_t* dst,
+    int ld_dst);
+
+template FBGEMM_API void transpose_simd<float>(
+    int M,
+    int N,
+    const float* src,
+    int ld_src,
+    float* dst,
+    int ld_dst);
+
+template FBGEMM_API void transpose_simd<uint8_t>(
+    int M,
+    int N,
+    const uint8_t* src,
+    int ld_src,
+    uint8_t* dst,
+    int ld_dst);
+
 } // namespace fbgemm
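A note on the explicit instantiations above: because the transpose_ref and transpose_simd templates are defined in this .cc file rather than in a header, each supported element type must be instantiated here, or callers in other translation units would fail to link. A minimal sketch of that pattern with a hypothetical copy_n helper (not part of FBGEMM):

// Illustration of the explicit-instantiation pattern: the template body lives
// in one .cc file, so only the listed element types are usable elsewhere.
#include <cstdint>

template <typename T>
void copy_n(const T* src, T* dst, int n) {
  for (int i = 0; i < n; ++i) {
    dst[i] = src[i];
  }
}

// Emit concrete symbols for the supported types; a caller asking for
// copy_n<int> from another translation unit would fail at link time.
template void copy_n<float>(const float*, float*, int);
template void copy_n<std::uint8_t>(const std::uint8_t*, std::uint8_t*, int);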
24 changes: 8 additions & 16 deletions src/TransposeUtils.h
@@ -18,13 +18,9 @@ namespace fbgemm {
  * @param dst The memory buffer of the destination matrix B.
  * @param ld_dst The leading dimension of the destination matrix B.
  */
-FBGEMM_API void transpose_ref(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst);
+template <typename T>
+FBGEMM_API void
+transpose_ref(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
 
 namespace internal {
 
@@ -33,25 +29,21 @@ namespace internal {
  *
  * This is called if the code is running on a CPU with Intel AVX2 support.
  */
-void transpose_avx2(
-    int M,
-    int N,
-    const float* src,
-    int ld_src,
-    float* dst,
-    int ld_dst);
+template <typename T>
+void transpose_avx2(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
 
 /**
  * @brief Transpose a matrix using Intel AVX512.
  *
  * This is called if the code is running on a CPU with Intel AVX512 support.
  */
+template <typename T>
 void transpose_avx512(
     int M,
     int N,
-    const float* src,
+    const T* src,
     int ld_src,
-    float* dst,
+    T* dst,
     int ld_dst);
 
 } // namespace internal
16 changes: 16 additions & 0 deletions src/UtilsAvx2.cc
@@ -12,6 +12,7 @@ namespace fbgemm {
 
 namespace internal {
 
+template <>
 void transpose_avx2(
     int M,
     int N,
@@ -169,6 +170,21 @@ void transpose_avx2(
   }
 }
 
+template <>
+void transpose_avx2(
+    int M,
+    int N,
+    const uint8_t* src,
+    int ld_src,
+    uint8_t* dst,
+    int ld_dst) {
+  for (int j = 0; j < N; j++) {
+    for (int i = 0; i < M; i++) {
+      dst[i + j * ld_dst] = src[i * ld_src + j];
+    }
+  } // for each output row
+}
+
 } // namespace internal
 
 } // namespace fbgemm
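To make the ld_src/ld_dst indexing convention concrete, here is a small standalone check of the scalar loop above on a 2x3 matrix stored with ld_src == N and ld_dst == M (illustration only, not FBGEMM code):

#include <cassert>
#include <cstdint>

// Same loop as the scalar fallback above, written out for a 2x3 example.
int main() {
  constexpr int M = 2, N = 3;
  const std::uint8_t src[M * N] = {1, 2, 3,   // row 0
                                   4, 5, 6};  // row 1
  std::uint8_t dst[N * M] = {};

  for (int j = 0; j < N; j++) {
    for (int i = 0; i < M; i++) {
      dst[i + j * M] = src[i * N + j]; // ld_dst == M, ld_src == N
    }
  }

  // dst now holds the 3x2 transpose: {1, 4, 2, 5, 3, 6}.
  assert(dst[0] == 1 && dst[1] == 4 && dst[2] == 2);
  assert(dst[3] == 5 && dst[4] == 3 && dst[5] == 6);
  return 0;
}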