Make matrix dimensions passed to transpose_simd unsigned (pytorch#449)
Summary:
Pull Request resolved: pytorch#449

They could not be negative, could they?

Reviewed By: jianyuh

Differential Revision: D24663461

fbshipit-source-id: 8446ae442d37a90f9e751adebc90e369b22053fa
malfet authored and facebook-github-bot committed Nov 3, 2020
1 parent 5b7566f commit 8eb6dcb
Showing 6 changed files with 72 additions and 72 deletions.
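For orientation, a minimal caller against the updated signature could look like the sketch below. This is not part of the commit; it assumes an FBGEMM checkout with include/ on the include path and linking against the fbgemm library, and the matrix sizes are made up.

```cpp
#include <cstdint>
#include <vector>

#include "fbgemm/Utils.h"  // declares fbgemm::transpose_simd

int main() {
  // Hypothetical 3x5 row-major source; leading dimensions equal the row widths.
  unsigned M = 3, N = 5;
  std::vector<float> src(M * N), dst(N * M);
  for (unsigned i = 0; i < M * N; ++i) {
    src[i] = static_cast<float>(i);
  }
  // With this change, every dimension/stride argument is unsigned.
  fbgemm::transpose_simd(M, N, src.data(), /*ld_src=*/N, dst.data(), /*ld_dst=*/M);
  return 0;
}
```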
2 changes: 1 addition & 1 deletion include/fbgemm/Utils.h
@@ -152,7 +152,7 @@ void printMatrix(
  */
 template <typename T>
 FBGEMM_API void
-transpose_simd(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
+transpose_simd(unsigned M, unsigned N, const T* src, unsigned ld_src, T* dst, unsigned ld_dst);

 /**
  * @brief Explicitly set instruction set to be used
46 changes: 23 additions & 23 deletions src/TransposeUtils.cc
@@ -12,22 +12,22 @@
 namespace fbgemm {

 template <typename T>
-void transpose_ref(int M, int N, const T* src, int ld_src, T* dst, int ld_dst) {
-  for (int j = 0; j < N; j++) {
-    for (int i = 0; i < M; i++) {
+void transpose_ref(unsigned M, unsigned N, const T* src, unsigned ld_src, T* dst, unsigned ld_dst) {
+  for (unsigned j = 0; j < N; j++) {
+    for (unsigned i = 0; i < M; i++) {
       dst[i + j * ld_dst] = src[i * ld_src + j];
     }
   } // for each output row
 }

 template <typename T>
 void transpose_simd(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const T* src,
-    int ld_src,
+    unsigned ld_src,
     T* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   if ((M == 1 && ld_dst == 1) || (N == 1 && ld_src == 1)) {
     if (dst != src) {
       // sizeof must be first operand force dims promotion to OS-bitness type
@@ -47,35 +47,35 @@ void transpose_simd(
 }

 template void transpose_ref<float>(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst);
+    unsigned ld_dst);

 template void transpose_ref<uint8_t>(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const uint8_t* src,
-    int ld_src,
+    unsigned ld_src,
     uint8_t* dst,
-    int ld_dst);
+    unsigned ld_dst);

 template FBGEMM_API void transpose_simd<float>(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst);
+    unsigned ld_dst);

 template FBGEMM_API void transpose_simd<uint8_t>(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const uint8_t* src,
-    int ld_src,
+    unsigned ld_src,
     uint8_t* dst,
-    int ld_dst);
+    unsigned ld_dst);

 } // namespace fbgemm
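The pre-existing comment in transpose_simd above ("sizeof must be first operand ...") is easy to gloss over now that the dims are 32-bit unsigned. Below is a standalone illustration of why the operand order matters, assuming a typical target where unsigned is 32 bits and size_t is 64 bits; this is not FBGEMM code.

```cpp
#include <cstddef>
#include <cstdio>

int main() {
  unsigned M = 100000, N = 100000;  // hypothetical large dims
  // sizeof(float) has type size_t, so putting it first promotes M and N to
  // 64-bit before the multiplies: 40,000,000,000 bytes, no overflow.
  std::size_t good = sizeof(float) * M * N;
  // Multiplying the two 32-bit unsigned values first wraps modulo 2^32
  // before the widening cast ever happens.
  std::size_t bad = static_cast<std::size_t>(M * N) * sizeof(float);
  std::printf("%zu vs %zu\n", good, bad);
  return 0;
}
```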
12 changes: 6 additions & 6 deletions src/TransposeUtils.h
@@ -20,7 +20,7 @@ namespace fbgemm {
  */
 template <typename T>
 FBGEMM_API void
-transpose_ref(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
+transpose_ref(unsigned M, unsigned N, const T* src, unsigned ld_src, T* dst, unsigned ld_dst);

 namespace internal {

@@ -30,7 +30,7 @@ namespace internal {
 * This is called if the code is running on a CPU with Intel AVX2 support.
 */
 template <typename T>
-void transpose_avx2(int M, int N, const T* src, int ld_src, T* dst, int ld_dst);
+void transpose_avx2(unsigned M, unsigned N, const T* src, unsigned ld_src, T* dst, unsigned ld_dst);

 /**
  * @brief Transpose a matrix using Intel AVX512.
@@ -39,12 +39,12 @@
  */
 template <typename T>
 void transpose_avx512(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const T* src,
-    int ld_src,
+    unsigned ld_src,
     T* dst,
-    int ld_dst);
+    unsigned ld_dst);

 } // namespace internal

26 changes: 13 additions & 13 deletions src/TransposeUtilsAvx2.h
@@ -19,7 +19,7 @@ namespace internal {

 // 4 * 4 = 16 instructions
 static inline void
-transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {
+transpose_kernel_4x4_sse(const float* src, unsigned ld_src, float* dst, unsigned ld_dst) {
   // load from src to registers
   // a : a0 a1 a2 a3
   // b : b0 b1 b2 b3
@@ -46,13 +46,13 @@ transpose_kernel_4x4_sse(const float* src, int ld_src, float* dst, int ld_dst) {

 // kernel for transpose mxn where m, n <= 4
 // M + (M + 1) / 2 * 2 + 2 * N instructions
-template <int M>
+template <unsigned M>
 static void transpose_kernel_mxn_sse(
-    int N,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   // clang-format off
   alignas(64) static const int masks[5][4] = {
     { 0, 0, 0, 0, },
@@ -66,7 +66,7 @@ static void transpose_kernel_mxn_sse(
   // load from src to registers
   __m128i mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(masks[N]));
   __m128 input[4];
-  int i;
+  unsigned i;
   for (i = 0; i < M; ++i) {
     input[i] = _mm_maskload_ps(&src[i * ld_src], mask_v);
   }
@@ -100,9 +100,9 @@ static void transpose_kernel_mxn_sse(
 // 8 * 5 = 40 instructions
 static inline void transpose_kernel_8x8_avx2(
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   // load from src to registers
   // a : a0 a1 a2 a3 a4 a5 a6 a7
   // b : b0 b1 b2 b3 b4 b5 b6 b7
@@ -190,18 +190,18 @@ static inline void transpose_kernel_8x8_avx2(

 // kernel for transposing mxn where m, n <= 8
 // M + (M + 1) / 2 * 2 + (M + 3) / 4 * 4 + 2 * N instructions
-template <int M>
+template <unsigned M>
 static void transpose_kernel_mxn_avx2(
-    int N,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   // load from src to registers
   __m256i mask_v = _mm256_load_si256(
       reinterpret_cast<const __m256i*>(internal::avx2_ps_or_epi32_masks[N]));
   __m256 input[8];
-  int i;
+  unsigned i;
   for (i = 0; i < M; ++i) {
     input[i] = _mm256_maskload_ps(&src[i * ld_src], mask_v);
   }
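The mxn kernels above handle partial tiles by loading only the first N valid columns through a mask table rather than branching per element. Here is a self-contained sketch of that masked-load idea; the helper name load_first_n is hypothetical, and the code needs AVX (compile with, e.g., -mavx2).

```cpp
#include <immintrin.h>
#include <cstdio>

// Load only the first n (n <= 4) floats of a row, zeroing the rest,
// mirroring how transpose_kernel_mxn_sse builds its masks[] table.
static __m128 load_first_n(const float* row, unsigned n) {
  alignas(16) int mask[4] = {0, 0, 0, 0};
  for (unsigned k = 0; k < n && k < 4; ++k) {
    mask[k] = -1;  // high bit set -> this lane is loaded
  }
  __m128i mask_v = _mm_load_si128(reinterpret_cast<const __m128i*>(mask));
  return _mm_maskload_ps(row, mask_v);  // AVX masked load
}

int main() {
  alignas(16) float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
  alignas(16) float out[4];
  _mm_store_ps(out, load_first_n(row, 3));
  std::printf("%g %g %g %g\n", out[0], out[1], out[2], out[3]);  // prints: 1 2 3 0
  return 0;
}
```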
28 changes: 14 additions & 14 deletions src/UtilsAvx2.cc
@@ -14,13 +14,13 @@ namespace internal {

 template <>
 void transpose_avx2(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
+    unsigned ld_dst) {
+  unsigned ib = 0, jb = 0;
   if (N % 8 > 0 && N % 8 < 4) {
     // If the remainder has n < 4 columns, we use the SSE kernel for the
     // remainder because it requires 2 * (2 * 4 + 2 * N) = 16 + 4N instructions
@@ -31,7 +31,7 @@ void transpose_avx2(
         transpose_kernel_8x8_avx2(
             &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
       }
-      for (int i = ib; i < ib + 8; i += 4) {
+      for (unsigned i = ib; i < ib + 8; i += 4) {
         transpose_kernel_mxn_sse<4>(
             N - jb,
             &src[i * ld_src + jb],
@@ -49,7 +49,7 @@ void transpose_avx2(
         transpose_kernel_8x8_avx2(
             &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
       }
-      for (int i = ib; i < ib + 8; i += 4) {
+      for (unsigned i = ib; i < ib + 8; i += 4) {
         transpose_kernel_4x4_sse(
             &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
       }
@@ -79,7 +79,7 @@ void transpose_avx2(
       // on m.
       switch (M - ib) {
         case 1:
-          for (int j = 0; j < N; ++j) {
+          for (unsigned j = 0; j < N; ++j) {
            dst[ib + j * ld_dst] = src[ib * ld_src + j];
          }
          break;
@@ -172,14 +172,14 @@

 template <>
 void transpose_avx2(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const uint8_t* src,
-    int ld_src,
+    unsigned ld_src,
     uint8_t* dst,
-    int ld_dst) {
-  for (int j = 0; j < N; j++) {
-    for (int i = 0; i < M; i++) {
+    unsigned ld_dst) {
+  for (unsigned j = 0; j < N; j++) {
+    for (unsigned i = 0; i < M; i++) {
       dst[i + j * ld_dst] = src[i * ld_src + j];
     }
   } // for each output row
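The float specialization above walks the matrix in 8x8 vector tiles and switches to narrower SSE or masked mxn kernels on the right and bottom edges. Below is a scalar sketch of the same block-plus-remainder decomposition, for illustration only; it is not the FBGEMM implementation, and the tile helper is hypothetical.

```cpp
// Scalar stand-in for one m x n transpose tile (dst is written column-major
// relative to src, i.e. dst[j * ld_dst + i] = src[i * ld_src + j]).
static void transpose_tile(const float* src, unsigned ld_src, float* dst,
                           unsigned ld_dst, unsigned m, unsigned n) {
  for (unsigned i = 0; i < m; ++i)
    for (unsigned j = 0; j < n; ++j)
      dst[j * ld_dst + i] = src[i * ld_src + j];
}

// Block-plus-remainder walk analogous to the AVX2 specialization:
// full 8x8 tiles in the interior, narrower tiles on the right and bottom edges.
void transpose_blocked(unsigned M, unsigned N, const float* src, unsigned ld_src,
                       float* dst, unsigned ld_dst) {
  const unsigned B = 8;
  unsigned ib = 0;
  for (; ib + B <= M; ib += B) {
    unsigned jb = 0;
    for (; jb + B <= N; jb += B)
      transpose_tile(&src[ib * ld_src + jb], ld_src, &dst[jb * ld_dst + ib], ld_dst, B, B);
    if (jb < N)  // right edge: full B rows, N - jb columns left
      transpose_tile(&src[ib * ld_src + jb], ld_src, &dst[jb * ld_dst + ib], ld_dst, B, N - jb);
  }
  if (ib < M)    // bottom edge: M - ib rows left, all N columns
    transpose_tile(&src[ib * ld_src], ld_src, &dst[ib], ld_dst, M - ib, N);
}
```

The real AVX2/AVX-512 paths pick the remainder kernel (4-wide SSE versus a masked mxn kernel) based on how many columns are left, per the instruction-count comments in the diff; the scalar sketch just shows how ib/jb carve the matrix.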
30 changes: 15 additions & 15 deletions src/UtilsAvx512.cc
@@ -16,9 +16,9 @@ namespace {
 // 16 * 6 = 96 instructions
 inline void transpose_kernel_16x16_avx512(
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   // load from src to registers
   // a: a0 a1 a2 a3 a4 a5 a6 a7 a8 a9 a10 a11 a12 a13 a14 a15
   // b: b0 b1 b2 b3 b4 b5 b6 b7 b8 b9 b10 b11 b12 b13 b14 b15
@@ -307,13 +307,13 @@ namespace internal {

 template <>
 void transpose_avx512(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const float* src,
-    int ld_src,
+    unsigned ld_src,
     float* dst,
-    int ld_dst) {
-  int ib = 0, jb = 0;
+    unsigned ld_dst) {
+  unsigned ib = 0, jb = 0;
   if (N % 16 > 0 && N % 16 < 4) {
     // If the remainder has n < 4 columns, we use the SSE kernel for the
     // remainder because it requires 4 * (2 * 4 + 2 * N) = 32 + 8N instructions
@@ -324,7 +324,7 @@ void transpose_avx512(
         transpose_kernel_16x16_avx512(
             &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
       }
-      for (int i = ib; i < ib + 16; i += 4) {
+      for (unsigned i = ib; i < ib + 16; i += 4) {
         transpose_kernel_mxn_sse<4>(
             N - jb,
             &src[i * ld_src + jb],
@@ -342,7 +342,7 @@ void transpose_avx512(
         transpose_kernel_16x16_avx512(
             &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
       }
-      for (int i = ib; i < ib + 16; i += 4) {
+      for (unsigned i = ib; i < ib + 16; i += 4) {
         transpose_kernel_4x4_sse(
             &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
       }
@@ -356,7 +356,7 @@
         transpose_kernel_16x16_avx512(
             &src[ib * ld_src + jb], ld_src, &dst[ib + jb * ld_dst], ld_dst);
       }
-      for (int i = ib; i < ib + 16; i += 8) {
+      for (unsigned i = ib; i < ib + 16; i += 8) {
        transpose_kernel_8x8_avx2(
            &src[i * ld_src + jb], ld_src, &dst[i + jb * ld_dst], ld_dst);
       }
@@ -386,7 +386,7 @@
       // on m.
       switch (M - ib) {
         case 1:
-          for (int j = 0; j < N; ++j) {
+          for (unsigned j = 0; j < N; ++j) {
            dst[ib + j * ld_dst] = src[ib * ld_src + j];
          }
          break;
@@ -1040,12 +1040,12 @@ void transpose_16x32_block(

 template <>
 void transpose_avx512(
-    int M,
-    int N,
+    unsigned M,
+    unsigned N,
     const uint8_t* src,
-    int ld_src,
+    unsigned ld_src,
     uint8_t* dst,
-    int ld_dst) {
+    unsigned ld_dst) {
   int i = 0;
   for (; i < M / 16 * 16; i += 16) {
     int j = 0;
