Skip to content

Commit

Permalink
x86 sse2/avx2 optimization for convolution sgemm/winograd int8 family
Browse files (browse the repository at this point in the history)
  • Loading branch information
nihui authored Oct 20, 2022
1 parent c33cbc9 commit 8eab5ea
Show file tree
Hide file tree
Showing 7 changed files with 590 additions and 1,259 deletions.
617 changes: 255 additions & 362 deletions src/layer/x86/convolution_3x3_pack8to1_int8.h

Large diffs are not rendered by default.

513 changes: 203 additions & 310 deletions src/layer/x86/convolution_3x3_pack8to4_int8.h

Large diffs are not rendered by default.

176 changes: 20 additions & 156 deletions src/layer/x86/convolution_sgemm_int8.h
Original file line number Diff line number Diff line change
Expand Up @@ -338,17 +338,8 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const

if (nn4 > 0)
{
#if __AVXVNNI__ || __AVX512VNNI__
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum30_22 = _mm256_setzero_si256();
#else
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum01_13 = _mm256_setzero_si256();
__m256i _sum11_03 = _mm256_setzero_si256();
__m256i _sum30_22 = _mm256_setzero_si256();
__m256i _sum21_33 = _mm256_setzero_si256();
__m256i _sum31_23 = _mm256_setzero_si256();
#endif

int j = 0;
for (; j < nn4; j++)
Expand All @@ -371,72 +362,21 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum20_32 = _mm256_dpwssd_epi32(_sum20_32, _val23_16, _w01_16);
_sum30_22 = _mm256_dpwssd_epi32(_sum30_22, _val32_16, _w01_16);
#else
__m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16);
__m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16);
__m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16);
__m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16);
__m256i _sl20_31 = _mm256_mullo_epi16(_val23_16, _w01_16);
__m256i _sh20_31 = _mm256_mulhi_epi16(_val23_16, _w01_16);
__m256i _sl30_21 = _mm256_mullo_epi16(_val32_16, _w01_16);
__m256i _sh30_21 = _mm256_mulhi_epi16(_val32_16, _w01_16);

_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01));
_sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11));
_sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01));
_sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_unpacklo_epi16(_sl20_31, _sh20_31));
_sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_unpacklo_epi16(_sl30_21, _sh30_21));
_sum21_33 = _mm256_add_epi32(_sum21_33, _mm256_unpackhi_epi16(_sl20_31, _sh20_31));
_sum31_23 = _mm256_add_epi32(_sum31_23, _mm256_unpackhi_epi16(_sl30_21, _sh30_21));
_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_madd_epi16(_val01_16, _w01_16));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_madd_epi16(_val10_16, _w01_16));
_sum20_32 = _mm256_add_epi32(_sum20_32, _mm256_madd_epi16(_val23_16, _w01_16));
_sum30_22 = _mm256_add_epi32(_sum30_22, _mm256_madd_epi16(_val32_16, _w01_16));
#endif

tmpptr += 16;
kptr0 += 16;
}

#if __AVXVNNI__ || __AVX512VNNI__
_sum00_12 = _mm256_hadd_epi32(_sum00_12, _sum10_02);
_sum20_32 = _mm256_hadd_epi32(_sum20_32, _sum30_22);

_sum00_12 = _mm256_permute4x64_epi64(_sum00_12, _MM_SHUFFLE(2, 1, 3, 0));
_sum20_32 = _mm256_permute4x64_epi64(_sum20_32, _MM_SHUFFLE(2, 1, 3, 0));
#else
// transpose 4x8
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02);
_tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03);
_tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02);
_tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum20_32, _sum30_22);
_tmp1 = _mm256_unpacklo_epi32(_sum21_33, _sum31_23);
_tmp2 = _mm256_unpackhi_epi32(_sum20_32, _sum30_22);
_tmp3 = _mm256_unpackhi_epi32(_sum21_33, _sum31_23);
_sum20_32 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum30_22 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum21_33 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum31_23 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}

_sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02);
_sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13);

_sum20_32 = _mm256_add_epi32(_sum20_32, _sum30_22);
_sum21_33 = _mm256_add_epi32(_sum21_33, _sum31_23);
_sum20_32 = _mm256_add_epi32(_sum20_32, _sum21_33);

__m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0);
_sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask);
_sum20_32 = _mm256_permutevar8x32_epi32(_sum20_32, _perm_mask);
#endif
}

__m128i _sum00 = _mm256_extracti128_si256(_sum00_12, 0);
Expand Down Expand Up @@ -532,25 +472,10 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
if (nn4 > 0)
{
#if __AVX2__
#if __AVXVNNI__ || __AVX512VNNI__
__m256i _sum10_02 = _mm256_setzero_si256();
#else
__m256i _sum10_02 = _mm256_setzero_si256();
__m256i _sum01_13 = _mm256_setzero_si256();
__m256i _sum11_03 = _mm256_setzero_si256();
#endif
#else
#if __XOP__
__m128i _sum01 = _mm_setzero_si128();
__m128i _sum11 = _mm_setzero_si128();
#else
__m128i _sum01 = _mm_setzero_si128();
__m128i _sum02 = _mm_setzero_si128();
__m128i _sum03 = _mm_setzero_si128();
__m128i _sum11 = _mm_setzero_si128();
__m128i _sum12 = _mm_setzero_si128();
__m128i _sum13 = _mm_setzero_si128();
#endif
#endif

int j = 0;
Expand All @@ -571,15 +496,8 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum00_12 = _mm256_dpwssd_epi32(_sum00_12, _val01_16, _w01_16);
_sum10_02 = _mm256_dpwssd_epi32(_sum10_02, _val10_16, _w01_16);
#else
__m256i _sl00_11 = _mm256_mullo_epi16(_val01_16, _w01_16);
__m256i _sh00_11 = _mm256_mulhi_epi16(_val01_16, _w01_16);
__m256i _sl10_01 = _mm256_mullo_epi16(_val10_16, _w01_16);
__m256i _sh10_01 = _mm256_mulhi_epi16(_val10_16, _w01_16);

_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_unpacklo_epi16(_sl00_11, _sh00_11));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_unpacklo_epi16(_sl10_01, _sh10_01));
_sum01_13 = _mm256_add_epi32(_sum01_13, _mm256_unpackhi_epi16(_sl00_11, _sh00_11));
_sum11_03 = _mm256_add_epi32(_sum11_03, _mm256_unpackhi_epi16(_sl10_01, _sh10_01));
_sum00_12 = _mm256_add_epi32(_sum00_12, _mm256_madd_epi16(_val01_16, _w01_16));
_sum10_02 = _mm256_add_epi32(_sum10_02, _mm256_madd_epi16(_val10_16, _w01_16));
#endif
#else
__m128i _val01 = _mm_loadl_epi64((const __m128i*)tmpptr);
Expand All @@ -604,23 +522,10 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
_sum10 = _mm_maddd_epi16(_val1, _w0, _sum10);
_sum11 = _mm_maddd_epi16(_val1, _w1, _sum11);
#else
__m128i _sl00 = _mm_mullo_epi16(_val0, _w0);
__m128i _sh00 = _mm_mulhi_epi16(_val0, _w0);
__m128i _sl01 = _mm_mullo_epi16(_val0, _w1);
__m128i _sh01 = _mm_mulhi_epi16(_val0, _w1);
__m128i _sl10 = _mm_mullo_epi16(_val1, _w0);
__m128i _sh10 = _mm_mulhi_epi16(_val1, _w0);
__m128i _sl11 = _mm_mullo_epi16(_val1, _w1);
__m128i _sh11 = _mm_mulhi_epi16(_val1, _w1);

_sum00 = _mm_add_epi32(_sum00, _mm_unpacklo_epi16(_sl00, _sh00));
_sum01 = _mm_add_epi32(_sum01, _mm_unpackhi_epi16(_sl00, _sh00));
_sum02 = _mm_add_epi32(_sum02, _mm_unpacklo_epi16(_sl01, _sh01));
_sum03 = _mm_add_epi32(_sum03, _mm_unpackhi_epi16(_sl01, _sh01));
_sum10 = _mm_add_epi32(_sum10, _mm_unpacklo_epi16(_sl10, _sh10));
_sum11 = _mm_add_epi32(_sum11, _mm_unpackhi_epi16(_sl10, _sh10));
_sum12 = _mm_add_epi32(_sum12, _mm_unpacklo_epi16(_sl11, _sh11));
_sum13 = _mm_add_epi32(_sum13, _mm_unpackhi_epi16(_sl11, _sh11));
_sum00 = _mm_add_epi32(_mm_madd_epi16(_val0, _w0), _sum00);
_sum01 = _mm_add_epi32(_mm_madd_epi16(_val0, _w1), _sum01);
_sum10 = _mm_add_epi32(_mm_madd_epi16(_val1, _w0), _sum10);
_sum11 = _mm_add_epi32(_mm_madd_epi16(_val1, _w1), _sum11);
#endif
#endif

Expand All @@ -629,67 +534,26 @@ static void im2col_sgemm_int8_sse(const Mat& bottom_im2col, Mat& top_blob, const
}

#if __AVX2__
#if __AVXVNNI__ || __AVX512VNNI__
_sum00_12 = _mm256_hadd_epi32(_sum00_12, _sum10_02);

_sum00_12 = _mm256_permute4x64_epi64(_sum00_12, _MM_SHUFFLE(2, 1, 3, 0));
#else
// transpose 4x8
{
__m256i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm256_unpacklo_epi32(_sum00_12, _sum10_02);
_tmp1 = _mm256_unpacklo_epi32(_sum01_13, _sum11_03);
_tmp2 = _mm256_unpackhi_epi32(_sum00_12, _sum10_02);
_tmp3 = _mm256_unpackhi_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_unpacklo_epi64(_tmp0, _tmp1);
_sum10_02 = _mm256_unpackhi_epi64(_tmp0, _tmp1);
_sum01_13 = _mm256_unpacklo_epi64(_tmp2, _tmp3);
_sum11_03 = _mm256_unpackhi_epi64(_tmp2, _tmp3);
}

_sum00_12 = _mm256_add_epi32(_sum00_12, _sum10_02);
_sum01_13 = _mm256_add_epi32(_sum01_13, _sum11_03);
_sum00_12 = _mm256_add_epi32(_sum00_12, _sum01_13);

__m256i _perm_mask = _mm256_set_epi32(6, 4, 3, 1, 7, 5, 2, 0);
_sum00_12 = _mm256_permutevar8x32_epi32(_sum00_12, _perm_mask);
#endif
#else
#if __XOP__
#if __SSSE3__
_sum00 = _mm_hadd_epi32(_sum00, _sum01);
_sum10 = _mm_hadd_epi32(_sum10, _sum11);
#else
// transpose 4x4
{
__m128i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm_unpacklo_epi32(_sum00, _sum01);
_tmp1 = _mm_unpacklo_epi32(_sum02, _sum03);
_tmp2 = _mm_unpackhi_epi32(_sum00, _sum01);
_tmp3 = _mm_unpackhi_epi32(_sum02, _sum03);
_sum00 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum01 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum02 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum03 = _mm_unpackhi_epi64(_tmp2, _tmp3);
}
{
__m128i _tmp0, _tmp1, _tmp2, _tmp3;
_tmp0 = _mm_unpacklo_epi32(_sum10, _sum11);
_tmp1 = _mm_unpacklo_epi32(_sum12, _sum13);
_tmp2 = _mm_unpackhi_epi32(_sum10, _sum11);
_tmp3 = _mm_unpackhi_epi32(_sum12, _sum13);
_sum10 = _mm_unpacklo_epi64(_tmp0, _tmp1);
_sum11 = _mm_unpackhi_epi64(_tmp0, _tmp1);
_sum12 = _mm_unpacklo_epi64(_tmp2, _tmp3);
_sum13 = _mm_unpackhi_epi64(_tmp2, _tmp3);
}
__m128i _sum00_sh = _mm_shuffle_epi32(_sum00, 216);
__m128i _sum01_sh = _mm_shuffle_epi32(_sum01, 216);
__m128i _sum10_sh = _mm_shuffle_epi32(_sum10, 216);
__m128i _sum11_sh = _mm_shuffle_epi32(_sum11, 216);

_sum00 = _mm_unpacklo_epi64(_sum00_sh, _sum01_sh);
_sum01 = _mm_unpackhi_epi64(_sum00_sh, _sum01_sh);
_sum10 = _mm_unpacklo_epi64(_sum10_sh, _sum11_sh);
_sum11 = _mm_unpackhi_epi64(_sum10_sh, _sum11_sh);

_sum00 = _mm_add_epi32(_sum00, _sum01);
_sum02 = _mm_add_epi32(_sum02, _sum03);
_sum10 = _mm_add_epi32(_sum10, _sum11);
_sum12 = _mm_add_epi32(_sum12, _sum13);

_sum00 = _mm_add_epi32(_sum00, _sum02);
_sum10 = _mm_add_epi32(_sum10, _sum12);
#endif
#endif
}
Expand Down
Loading

0 comments on commit 8eab5ea

Please sign in to comment.