Skip to content

Commit

Permalink
Optimize FindFirstCharToEncode for JavaScriptEncoder.Default using Ss…
Browse files Browse the repository at this point in the history
…se3 intrinsics (dotnet/corefx#42073)

Commit migrated from dotnet/corefx@ba320d4
  • Loading branch information
gfoidl authored and GrabYourPitchforks committed Nov 3, 2019
1 parent 420e273 commit 9e08ecd
Show file tree
Hide file tree
Showing 5 changed files with 343 additions and 92 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
</ItemGroup>
<ItemGroup Condition="'$(TargetsNetCoreApp)' == 'true'">
<Compile Include="System\Text\Encodings\Web\Sse2Helper.cs" />
<Compile Include="System\Text\Encodings\Web\Ssse3Helper.cs" />
</ItemGroup>
<ItemGroup>
<Compile Include="$(CommonPath)\CoreLib\System\Text\UnicodeDebug.cs">
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,122 +73,268 @@ public override bool WillEncode(int unicodeScalar)
return NeedsEscaping((char)unicodeScalar);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override unsafe int FindFirstCharacterToEncode(char* text, int textLength)
{
if (text == null)
{
throw new ArgumentNullException(nameof(text));
}

Debug.Assert(textLength >= 0);

if (textLength == 0)
{
goto AllAllowed;
}

int idx = 0;
short* ptr = (short*)text;
short* end = ptr + (uint)textLength;

#if NETCOREAPP
if (Sse2.IsSupported && textLength >= Vector128<short>.Count)
{
goto VectorizedEntry;
}

Sequential:
#endif
Debug.Assert(textLength > 0 && ptr < end);

do
{
Debug.Assert(text <= ptr && ptr < (text + textLength));

if (NeedsEscaping(*(char*)ptr))
{
goto Return;
}

ptr++;
idx++;
}
while (ptr < end);

AllAllowed:
idx = -1;

Return:
return idx;

#if NETCOREAPP
if (Sse2.IsSupported)
VectorizedEntry:
int index;
short* vectorizedEnd;

if (textLength >= 2 * Vector128<short>.Count)
{
short* startingAddress = (short*)text;
while (textLength - 8 >= idx)
vectorizedEnd = end - 2 * Vector128<short>.Count;

do
{
Debug.Assert(startingAddress >= text && startingAddress <= (text + textLength - 8));
Debug.Assert(text <= ptr && ptr <= (text + textLength - 2 * Vector128<short>.Count));

// Load the next 8 characters.
Vector128<short> sourceValue = Sse2.LoadVector128(startingAddress);
// Load the next 16 characters, combine them to one byte vector.
// Chars that don't cleanly convert to ASCII bytes will get converted (saturated) to
// somewhere in the range [0x7F, 0xFF], which the NeedsEscaping method will detect.
Vector128<sbyte> sourceValue = Sse2.PackSignedSaturate(
Sse2.LoadVector128(ptr),
Sse2.LoadVector128(ptr + Vector128<short>.Count));

// Check if any of the 8 characters need to be escaped.
Vector128<short> mask = Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue);
// Check if any of the 16 characters need to be escaped.
index = NeedsEscaping(sourceValue);

int index = Sse2.MoveMask(mask.AsByte());
// If index == 0, that means none of the 8 characters needed to be escaped.
// If index == 0, that means none of the 16 characters needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
// Found at least one character that needs to be escaped, figure out the index of
// the first one found that needed to be escaped within the 8 characters.
Debug.Assert(index > 0 && index <= 65_535);
int tzc = BitOperations.TrailingZeroCount(index);
Debug.Assert(tzc % 2 == 0 && tzc >= 0 && tzc <= 16);
idx += tzc >> 1;
goto Return;
goto VectorizedFound;
}
idx += 8;
startingAddress += 8;
}

// Process the remaining characters.
Debug.Assert(textLength - idx < 8);
ptr += 2 * Vector128<short>.Count;
}
while (ptr <= vectorizedEnd);
}
#endif

for (; idx < textLength; idx++)
vectorizedEnd = end - Vector128<short>.Count;

Vectorized:
// PERF: JIT produces better code for a do-while loop than for a while-loop (no spills)
if (ptr <= vectorizedEnd)
{
Debug.Assert((text + idx) <= (text + textLength));
if (NeedsEscaping(*(text + idx)))
do
{
goto Return;
Debug.Assert(text <= ptr && ptr <= (text + textLength - Vector128<short>.Count));

// Load the next 8 characters plus a dummy character that is known not to need escaping.
// Put the dummy second, so it's easier for GetIndexOfFirstNeedToEscape.
Vector128<sbyte> sourceValue = Sse2.PackSignedSaturate(
Sse2.LoadVector128(ptr),
Vector128.Create((short)'A')); // max. one "iteration", so no need to cache this vector

index = NeedsEscaping(sourceValue);

// If index == 0, that means none of the 16 bytes needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
goto VectorizedFound;
}

ptr += Vector128<short>.Count;
}
while (ptr <= vectorizedEnd);
}

idx = -1; // All characters are allowed.
// Process the remaining characters.
Debug.Assert(end - ptr < Vector128<short>.Count);

Return:
// Process the remaining elements vectorized only if the remaining count
// is above thresholdForRemainingVectorized; otherwise process them sequentially.
// Threshold found by testing.
const int thresholdForRemainingVectorized = 5;
if (ptr < end - thresholdForRemainingVectorized)
{
ptr = vectorizedEnd;
goto Vectorized;
}

idx = CalculateIndex(ptr, text);

if (idx < textLength)
{
goto Sequential;
}

goto AllAllowed;

VectorizedFound:
idx = GetIndexOfFirstNeedToEscape(index);
idx += CalculateIndex(ptr, text);
return idx;

static int CalculateIndex(short* ptr, char* text)
{
// Subtraction with short* results in an idiv, so use byte* and shift
return (int)(((byte*)ptr - (byte*)text) >> 1);
}
#endif
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public override unsafe int FindFirstCharacterToEncodeUtf8(ReadOnlySpan<byte> utf8Text)
{
fixed (byte* ptr = utf8Text)
fixed (byte* pValue = utf8Text)
{
uint textLength = (uint)utf8Text.Length;

if (textLength == 0)
{
goto AllAllowed;
}

int idx = 0;
byte* ptr = pValue;
byte* end = ptr + textLength;

#if NETCOREAPP
if (Sse2.IsSupported)
{
sbyte* startingAddress = (sbyte*)ptr;
while (utf8Text.Length - 16 >= idx)
{
Debug.Assert(startingAddress >= ptr && startingAddress <= (ptr + utf8Text.Length - 16));

// Load the next 16 bytes.
Vector128<sbyte> sourceValue = Sse2.LoadVector128(startingAddress);

// Check if any of the 16 bytes need to be escaped.
Vector128<sbyte> mask = Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue);

int index = Sse2.MoveMask(mask);
// If index == 0, that means none of the 16 bytes needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
// Found at least one byte that needs to be escaped, figure out the index of
// the first one found that needed to be escaped within the 16 bytes.
int tzc = BitOperations.TrailingZeroCount(index);
Debug.Assert(tzc >= 0 && tzc <= 16);
idx += tzc;
goto Return;
}
idx += 16;
startingAddress += 16;
}

// Process the remaining bytes.
Debug.Assert(utf8Text.Length - idx < 16);
if (Sse2.IsSupported && textLength >= Vector128<sbyte>.Count)
{
goto Vectorized;
}

Sequential:
#endif
Debug.Assert(textLength > 0 && ptr < end);

for (; idx < utf8Text.Length; idx++)
do
{
Debug.Assert((ptr + idx) <= (ptr + utf8Text.Length));
if (NeedsEscaping(*(ptr + idx)))
Debug.Assert(pValue <= ptr && ptr < (pValue + utf8Text.Length));

if (NeedsEscaping(*ptr))
{
goto Return;
}

ptr++;
idx++;
}
while (ptr < end);

idx = -1; // All bytes are allowed.
AllAllowed:
idx = -1;

Return:
return idx;

#if NETCOREAPP
Vectorized:
byte* vectorizedEnd = end - Vector128<byte>.Count;
int index;

do
{
Debug.Assert(pValue <= ptr && ptr <= (pValue + utf8Text.Length - Vector128<byte>.Count));
// Load the next 16 bytes
Vector128<sbyte> sourceValue = Sse2.LoadVector128((sbyte*)ptr);

index = NeedsEscaping(sourceValue);

// If index == 0, that means none of the 16 bytes needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
goto VectorizedFound;
}

ptr += Vector128<sbyte>.Count;
}
while (ptr <= vectorizedEnd);

// Process the remaining elements.
Debug.Assert(end - ptr < Vector128<byte>.Count);

// Process the remaining elements vectorized only if the remaining count
// is above thresholdForRemainingVectorized; otherwise process them sequentially.
const int thresholdForRemainingVectorized = 4;
if (ptr < end - thresholdForRemainingVectorized)
{
// PERF: duplicate instead of jumping at the beginning of the previous loop
// otherwise all the static data (vectors) will be re-assigned to registers,
// so they are re-used.

Debug.Assert(pValue <= vectorizedEnd && vectorizedEnd <= (pValue + utf8Text.Length - Vector128<byte>.Count));

// Load the last 16 bytes
Vector128<sbyte> sourceValue = Sse2.LoadVector128((sbyte*)vectorizedEnd);

index = NeedsEscaping(sourceValue);
if (index != 0)
{
ptr = vectorizedEnd;
goto VectorizedFound;
}

idx = -1;
goto Return;
}

idx = CalculateIndex(ptr, pValue);

if (idx < textLength)
{
goto Sequential;
}

goto AllAllowed;

VectorizedFound:
idx = GetIndexOfFirstNeedToEscape(index);
idx += CalculateIndex(ptr, pValue);
return idx;

static int CalculateIndex(byte* ptr, byte* pValue) => (int)(ptr - pValue);
#endif
}
}

Expand Down Expand Up @@ -285,5 +431,35 @@ public override unsafe bool TryEncodeUnicodeScalar(int unicodeScalar, char* buff

/// <summary>
/// Scalar escaping check for a single UTF-16 code unit.
/// </summary>
/// <param name="value">The character to test.</param>
/// <returns>
/// <see langword="true"/> when <paramref name="value"/> is greater than
/// <c>LastAsciiCharacter</c> or when its <c>AllowList</c> entry is zero;
/// otherwise <see langword="false"/>.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool NeedsEscaping(char value)
{
    // Anything past the last ASCII character is rejected without consulting the table,
    // which also keeps the table lookup below within the ASCII range.
    if (value > LastAsciiCharacter)
    {
        return true;
    }

    return AllowList[value] == 0;
}

#if NETCOREAPP
/// <summary>
/// Vectorized escaping check for 16 packed bytes.
/// </summary>
/// <param name="sourceValue">The 16 bytes to inspect.</param>
/// <returns>
/// The result of <c>Sse2.MoveMask</c> over the escaping mask. A value of zero means
/// that none of the 16 bytes needs to be escaped; a non-zero value is meant to be
/// passed to <c>GetIndexOfFirstNeedToEscape</c> to locate the first offending byte.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int NeedsEscaping(Vector128<sbyte> sourceValue)
{
    Debug.Assert(Sse2.IsSupported);

    // Build the escaping mask with the SSSE3 helper when the hardware offers it,
    // and with the SSE2 helper otherwise.
    Vector128<sbyte> escapingMask;
    if (Ssse3.IsSupported)
    {
        escapingMask = Ssse3Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue);
    }
    else
    {
        escapingMask = Sse2Helper.CreateEscapingMask_DefaultJavaScriptEncoderBasicLatin(sourceValue);
    }

    // Collapse the vector mask into a scalar bitmask.
    return Sse2.MoveMask(escapingMask.AsByte());
}

// PERF: keep this as a separate method. Don't inline it by hand and don't call it
// from within NeedsEscaping, as the resulting asm won't be great.
/// <summary>
/// Translates a non-zero escaping bitmask into the zero-based position of the
/// first byte, within the 16 inspected bytes, that needs to be escaped.
/// </summary>
/// <param name="index">
/// The bitmask to inspect. It is expected to be in the range [1, 65535]; this is
/// only verified in debug builds.
/// </param>
/// <returns>
/// The number of trailing zero bits of <paramref name="index"/>, i.e. the position
/// of its lowest set bit, which lies in the range [0, 15] for a valid mask.
/// </returns>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int GetIndexOfFirstNeedToEscape(int index)
{
    // Found at least one byte that needs to be escaped, figure out the index of
    // the first one found that needed to be escaped within the 16 bytes.
    Debug.Assert(index > 0 && index <= 65_535);
    int tzc = BitOperations.TrailingZeroCount(index);

    // A non-zero 16-bit mask always has its lowest set bit at position 0..15,
    // so 16 is not a valid result and must not be accepted by the assert.
    Debug.Assert(tzc >= 0 && tzc < 16);

    return tzc;
}
#endif
}
}
Loading

0 comments on commit 9e08ecd

Please sign in to comment.