Skip to content

Commit

Permalink
Transpose for 16b tensors (microsoft#14877)
Browse files Browse the repository at this point in the history
### Description
Matrix transpose for 16b tensors (shorts, and half precision floats)


### Motivation and Context

Need it for fp16 operations
  • Loading branch information
chenfucn authored Mar 2, 2023
1 parent 7cd4b33 commit 603026f
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 2 deletions.
9 changes: 9 additions & 0 deletions onnxruntime/core/mlas/inc/mlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -1051,6 +1051,15 @@ MlasTranspose(
size_t N
);

void
MLASCALL
MlasTranspose(
const uint16_t* Input,
uint16_t* Output,
size_t M,
size_t N
);

void
MLASCALL
MlasTranspose(
Expand Down
162 changes: 162 additions & 0 deletions onnxruntime/core/mlas/lib/transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,32 @@ MlasTranspose4x4Block(
_mm_storeu_si128((__m128i*)&Output[OutputStride * 3], c3);
}

MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
const uint16_t* Input,
size_t InputStride,
uint16_t* Output,
size_t OutputStride
)
{
__m128i a0 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 0]);
__m128i a1 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 1]);
__m128i a2 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 2]);
__m128i a3 = _mm_loadl_epi64((const __m128i*)&Input[InputStride * 3]);

__m128i b0 = _mm_unpacklo_epi16(a0, a2);
__m128i b1 = _mm_unpacklo_epi16(a1, a3);

__m128i c0 = _mm_unpacklo_epi16(b0, b1);
__m128i c1 = _mm_unpackhi_epi16(b0, b1);

_mm_storel_pi((__m64*)&Output[OutputStride * 0], _mm_castsi128_ps(c0));
_mm_storeh_pi((__m64*)&Output[OutputStride * 1], _mm_castsi128_ps(c0));
_mm_storel_pi((__m64*)&Output[OutputStride * 2], _mm_castsi128_ps(c1));
_mm_storeh_pi((__m64*)&Output[OutputStride * 3], _mm_castsi128_ps(c1));
}

MLAS_FORCEINLINE
void
MlasTranspose8x8Block(
Expand Down Expand Up @@ -123,6 +149,32 @@ MlasTranspose4x4Block(
vst1q_u32(&Output[OutputStride * 3], c1.val[1]);
}

MLAS_FORCEINLINE
void
MlasTranspose4x4Block(
const uint16_t* Input,
size_t InputStride,
uint16_t* Output,
size_t OutputStride
)
{
uint16x4_t a0 = vld1_u16(&Input[InputStride * 0]);
uint16x4_t a1 = vld1_u16(&Input[InputStride * 1]);
uint16x4_t a2 = vld1_u16(&Input[InputStride * 2]);
uint16x4_t a3 = vld1_u16(&Input[InputStride * 3]);

uint16x4x2_t b0 = vzip_u16(a0, a2);
uint16x4x2_t b1 = vzip_u16(a1, a3);

uint16x4x2_t c0 = vzip_u16(b0.val[0], b1.val[0]);
uint16x4x2_t c1 = vzip_u16(b0.val[1], b1.val[1]);

vst1_u16(&Output[OutputStride * 0], c0.val[0]);
vst1_u16(&Output[OutputStride * 1], c0.val[1]);
vst1_u16(&Output[OutputStride * 2], c1.val[0]);
vst1_u16(&Output[OutputStride * 3], c1.val[1]);
}

MLAS_FORCEINLINE
void
MlasTranspose8x8Block(
Expand Down Expand Up @@ -498,6 +550,116 @@ MlasTranspose(
N);
}


void
MLASCALL
MlasTranspose(
const uint16_t* Input,
uint16_t* Output,
size_t M,
size_t N
)
/*++
Routine Description:
This routine transposes the input matrix (M rows by N columns) to the
output matrix (N rows by M columns).
Arguments:
Input - Supplies the input buffer.
Output - Supplies the output buffer.
M - Supplies the number of rows for the input matrix and the number of
columns for the output matrix.
N - Supplies the number of columns for the input matrix and the number of
rows for the output matrix.
Return Value:
None.
--*/
{
size_t n = N;

//
// Transpose elements from the input matrix to the output matrix 4 columns
// at a time.
//

while (n >= 4) {

const uint16_t* s = Input;
uint16_t* d = Output;
size_t m = M;

#if defined(MLAS_SSE2_INTRINSICS) || defined(MLAS_NEON_INTRINSICS)

while (m >= 4) {

MlasTranspose4x4Block(s, N, d, M);

s += N * 4;
d += 4;
m -= 4;
}

#endif

while (m > 0) {

MlasTranspose4xNVector(s, 1, d, M);

s += N;
d += 1;
m -= 1;
}

Input += 4;
Output += M * 4;
n -= 4;
}

//
// Transpose elements from the input matrix to the output matrix for the
// remaining columns.
//

while (n > 0) {

const uint16_t* s = Input;
uint16_t* d = Output;
size_t m = M;

while (m >= 4) {

MlasTranspose4xNVector(s, N, d, 1);

s += N * 4;
d += 4;
m -= 4;
}

while (m > 0) {

d[0] = s[0];

s += N;
d += 1;
m -= 1;
}

Input += 1;
Output += M;
n -= 1;
}
}


void
MLASCALL
MlasTranspose(
Expand Down
6 changes: 4 additions & 2 deletions onnxruntime/test/mlas/unittest/test_transpose.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ class MlasTransposeTest : public MlasTestBase {
};

template <> MlasTransposeTest<uint32_t>* MlasTestFixture<MlasTransposeTest<uint32_t>>::mlas_tester(nullptr);
template <> MlasTransposeTest<uint16_t>* MlasTestFixture<MlasTransposeTest<uint16_t>>::mlas_tester(nullptr);
template <> MlasTransposeTest<uint8_t>* MlasTestFixture<MlasTransposeTest<uint8_t>>::mlas_tester(nullptr);

static UNUSED_VARIABLE bool added_to_main = AddTestRegister([](bool is_short_execute) {
size_t count = 0;
if (is_short_execute) {
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint32_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint16_t>>::RegisterShortExecute();
count += MlasDirectShortExecuteTests<MlasTransposeTest<uint8_t>>::RegisterShortExecute();
}
return count;
});

0 comments on commit 603026f

Please sign in to comment.