Skip to content

Commit

Permalink
Improve SSE2 implementations in x86 targets. (Tencent#3605)
Browse files Browse the repository at this point in the history
* Fix some typos in the SSE2 floor implementation.

* Improve the implementation of SSE2 abs.

* Improve the implementation of SSE2 ceil.
  • Loading branch information
MouriNaruto authored Mar 9, 2022
1 parent 2b4a212 commit 8e29c42
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 31 deletions.
72 changes: 42 additions & 30 deletions src/layer/x86/unaryop_x86.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,17 +135,17 @@ struct unary_op_floor
return (__m128)_mm_floor_ps(x);
#endif // __SSE4_1__

// The sign bit mask.
const __m128 magic_sign_bit = _mm_set_ps1(-0.0f);
// Use negative zero as the sign bit mask.
const __m128 magic_negative_zero = _mm_set_ps1(-0.0f);

// The smallest magnitude at which every float is already an integer, i.e. has no fractional part. (2^23)
const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f);

// absolute = abs(x);
__m128 absolute = _mm_andnot_ps(magic_sign_bit, x);
__m128 absolute = _mm_andnot_ps(magic_negative_zero, x);

// negative_mask = magic_sign_bit && x;
__m128 negative_mask = _mm_and_ps(magic_sign_bit, x);
// negative_mask = magic_negative_zero && x;
__m128 negative_mask = _mm_and_ps(magic_negative_zero, x);

// no_fraction = (magic_smallest_no_fraction < absolute);
__m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute);
Expand All @@ -164,7 +164,7 @@ struct unary_op_floor
// fixed_result = truncated_with_sign - negative_fix;
__m128 fixed_result = _mm_sub_ps(truncated_with_sign, negative_fix);

// return ((x && no_fraction) || (!no_fraction && negative_fix));
// return ((x && no_fraction) || (!no_fraction && fixed_result));
return _mm_or_ps(
_mm_and_ps(x, no_fraction),
_mm_andnot_ps(no_fraction, fixed_result));
Expand All @@ -190,30 +190,42 @@ struct unary_op_ceil
#if __SSE4_1__
return (__m128)_mm_ceil_ps(x);
#endif // __SSE4_1__
const __m128 magic_negative_one = _mm_set_ps1(-1.0f);
const __m128 magic_infinity = _mm_set_ps1(INFINITY);
const __m128 magic_max_fraction = _mm_set_ps1(8388607.5f);

__m128i v1 = _mm_castps_si128(x);
__m128i v2 = _mm_srai_epi32(v1, 31);
__m128i v3 = _mm_cmpgt_epi32(
_mm_and_si128(v1, _mm_castps_si128(magic_infinity)),
_mm_castps_si128(magic_max_fraction));
__m128 v4 = _mm_castsi128_ps(_mm_or_si128(
_mm_andnot_si128(
v3,
_mm_or_si128(
_mm_castps_si128(_mm_cvtepi32_ps(_mm_cvttps_epi32(x))),
_mm_slli_epi32(v2, 31))),
_mm_and_si128(v1, v3)));

return _mm_sub_ps(
v4,
_mm_castsi128_ps(_mm_andnot_si128(
v2,
_mm_andnot_si128(
_mm_cmpeq_epi32(v1, _mm_castps_si128(v4)),
_mm_castps_si128(magic_negative_one)))));

// Use negative zero as the sign bit mask.
const __m128 magic_negative_zero = _mm_set_ps1(-0.0f);

// The smallest magnitude at which every float is already an integer, i.e. has no fractional part. (2^23)
const __m128 magic_smallest_no_fraction = _mm_set_ps1(8388608.0f);

// absolute = abs(x);
__m128 absolute = _mm_andnot_ps(magic_negative_zero, x);

// negative_mask = magic_negative_zero && x;
__m128 negative_mask = _mm_and_ps(magic_negative_zero, x);

// no_fraction = (magic_smallest_no_fraction < absolute);
__m128 no_fraction = _mm_cmplt_ps(magic_smallest_no_fraction, absolute);

// truncated = static_cast<float>(static_cast<int32_t>(absolute));
__m128 truncated = _mm_cvtepi32_ps(_mm_cvttps_epi32(absolute));

// truncated_with_sign = (truncated || negative_mask);
__m128 truncated_with_sign = _mm_or_ps(truncated, negative_mask);

// positive_fix = ((x > -0.0f) && (x > truncated_with_sign) ? -1.0f : 0.0f);
__m128 positive_fix = _mm_and_ps(
_mm_and_ps(
_mm_cmpgt_ps(x, magic_negative_zero),
_mm_cmpgt_ps(x, truncated_with_sign)),
_mm_set_ps1(-1.0f));

// fixed_result = truncated_with_sign - positive_fix;
__m128 fixed_result = _mm_sub_ps(truncated_with_sign, positive_fix);

// return ((x && no_fraction) || (!no_fraction && fixed_result));
return _mm_or_ps(
_mm_and_ps(x, no_fraction),
_mm_andnot_ps(no_fraction, fixed_result));
}
#if __AVX__
__m256 operator()(const __m256& x) const
Expand Down
6 changes: 5 additions & 1 deletion src/layer/x86/x86_activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,11 @@ static NCNN_FORCEINLINE __m128 hardswish_sse(__m128 inputs, __m128 a, __m128 b)

// Computes the element-wise absolute value of the four packed floats in `inputs`.
static NCNN_FORCEINLINE __m128 abs_sse(__m128 inputs)
{
// NOTE(review): this listing looks like a rendered diff without +/- markers; the
// statement below (max(0 - inputs, inputs)) is presumably the replaced implementation
// and the lines after it the new one — confirm against the repository before reuse.
return _mm_max_ps(_mm_sub_ps(_mm_setzero_ps(), inputs), inputs);
// Use negative zero as the sign bit mask.
const __m128 magic_negative_zero = _mm_set_ps1(-0.0f);

// return (~magic_negative_zero & inputs), i.e. inputs with the bits set in the mask cleared.
return _mm_andnot_ps(magic_negative_zero, inputs);
}

static NCNN_FORCEINLINE __m128 lrelu_sse(__m128 inputs, float slope)
Expand Down

0 comments on commit 8e29c42

Please sign in to comment.