Skip to content

Commit

Permalink
Use Sse2 instrinsics to make NeedsEscaping check faster for large JSO…
Browse files Browse the repository at this point in the history
…N strings (dotnet/corefx#41845)

* Use Sse2 instrinsics to make NeedsEscaping check faster for large
strings.

* Update the utf-8 bytes needsescaping and add tests.

* Remove unnecessary bitwise OR and add more tests

* Add more tests around surrogates, invalid strings, and characters >
short.MaxValue.


Commit migrated from dotnet/corefx@7cae92b
  • Loading branch information
ahsonkhan authored Oct 22, 2019
1 parent 8830130 commit 0089be5
Show file tree
Hide file tree
Showing 3 changed files with 571 additions and 29 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Text.Json/src/System.Text.Json.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@
<Reference Include="System.Resources.ResourceManager" />
<Reference Include="System.Runtime" />
<Reference Include="System.Runtime.Extensions" />
<Reference Include="System.Runtime.Intrinsics" />
<Reference Include="System.Text.Encoding.Extensions" />
</ItemGroup>
<ItemGroup>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@
using System.Buffers;
using System.Buffers.Text;
using System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Text.Encodings.Web;

#if BUILDING_INBOX_LIBRARY
using System.Runtime.Intrinsics;
using System.Runtime.Intrinsics.X86;
#endif

namespace System.Text.Json
{
// TODO: Replace the escaping logic with publicly shipping APIs from https://github.com/dotnet/corefx/issues/33509
Expand Down Expand Up @@ -55,57 +61,202 @@ internal static partial class JsonWriterHelper
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static bool NeedsEscaping(char value) => value > LastAsciiCharacter || AllowList[value] == 0;

public static int NeedsEscaping(ReadOnlySpan<byte> value, JavaScriptEncoder encoder)
#if BUILDING_INBOX_LIBRARY
private static readonly Vector128<short> s_mask_UInt16_0x20 = Vector128.Create((short)0x20); // Space ' '

private static readonly Vector128<short> s_mask_UInt16_0x22 = Vector128.Create((short)0x22); // Quotation Mark '"'
private static readonly Vector128<short> s_mask_UInt16_0x26 = Vector128.Create((short)0x26); // Ampersand '&'
private static readonly Vector128<short> s_mask_UInt16_0x27 = Vector128.Create((short)0x27); // Apostrophe '''
private static readonly Vector128<short> s_mask_UInt16_0x2B = Vector128.Create((short)0x2B); // Plus sign '+'
private static readonly Vector128<short> s_mask_UInt16_0x3C = Vector128.Create((short)0x3C); // Less Than Sign '<'
private static readonly Vector128<short> s_mask_UInt16_0x3E = Vector128.Create((short)0x3E); // Greater Than Sign '>'
private static readonly Vector128<short> s_mask_UInt16_0x5C = Vector128.Create((short)0x5C); // Reverse Solidus '\'
private static readonly Vector128<short> s_mask_UInt16_0x60 = Vector128.Create((short)0x60); // Grave Access '`'

private static readonly Vector128<short> s_mask_UInt16_0x7E = Vector128.Create((short)0x7E); // Tilde '~'

private static readonly Vector128<sbyte> s_mask_SByte_0x20 = Vector128.Create((sbyte)0x20); // Space ' '

private static readonly Vector128<sbyte> s_mask_SByte_0x22 = Vector128.Create((sbyte)0x22); // Quotation Mark '"'
private static readonly Vector128<sbyte> s_mask_SByte_0x26 = Vector128.Create((sbyte)0x26); // Ampersand '&'
private static readonly Vector128<sbyte> s_mask_SByte_0x27 = Vector128.Create((sbyte)0x27); // Apostrophe '''
private static readonly Vector128<sbyte> s_mask_SByte_0x2B = Vector128.Create((sbyte)0x2B); // Plus sign '+'
private static readonly Vector128<sbyte> s_mask_SByte_0x3C = Vector128.Create((sbyte)0x3C); // Less Than Sign '<'
private static readonly Vector128<sbyte> s_mask_SByte_0x3E = Vector128.Create((sbyte)0x3E); // Greater Than Sign '>'
private static readonly Vector128<sbyte> s_mask_SByte_0x5C = Vector128.Create((sbyte)0x5C); // Reverse Solidus '\'
private static readonly Vector128<sbyte> s_mask_SByte_0x60 = Vector128.Create((sbyte)0x60); // Grave Access '`'

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<short> CreateEscapingMask(Vector128<short> sourceValue)
{
int idx;
Debug.Assert(Sse2.IsSupported);

if (encoder != null)
{
idx = encoder.FindFirstCharacterToEncodeUtf8(value);
goto Return;
}
Vector128<short> mask = Sse2.CompareLessThan(sourceValue, s_mask_UInt16_0x20); // Space ' ', anything in the control characters range

mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x22)); // Quotation Mark '"'
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x26)); // Ampersand '&'
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x27)); // Apostrophe '''
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x2B)); // Plus sign '+'

mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x3C)); // Less Than Sign '<'
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x3E)); // Greater Than Sign '>'
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x5C)); // Reverse Solidus '\'
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_UInt16_0x60)); // Grave Access '`'

mask = Sse2.Or(mask, Sse2.CompareGreaterThan(sourceValue, s_mask_UInt16_0x7E)); // Tilde '~', anything above the ASCII range

return mask;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static Vector128<sbyte> CreateEscapingMask(Vector128<sbyte> sourceValue)
{
Debug.Assert(Sse2.IsSupported);

for (idx = 0; idx < value.Length; idx++)
Vector128<sbyte> mask = Sse2.CompareLessThan(sourceValue, s_mask_SByte_0x20); // Control characters, and anything above 0x7E since sbyte.MaxValue is 0x7E

mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x22)); // Quotation Mark "
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x26)); // Ampersand &
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x27)); // Apostrophe '
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x2B)); // Plus sign +

mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x3C)); // Less Than Sign <
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x3E)); // Greater Than Sign >
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x5C)); // Reverse Solidus \
mask = Sse2.Or(mask, Sse2.CompareEqual(sourceValue, s_mask_SByte_0x60)); // Grave Access `

return mask;
}
#endif

public static unsafe int NeedsEscaping(ReadOnlySpan<byte> value, JavaScriptEncoder encoder)
{
fixed (byte* ptr = value)
{
if (NeedsEscaping(value[idx]))
int idx = 0;

if (encoder != null)
{
idx = encoder.FindFirstCharacterToEncodeUtf8(value);
goto Return;
}
}

idx = -1; // all characters allowed
#if BUILDING_INBOX_LIBRARY
if (Sse2.IsSupported)
{
sbyte* startingAddress = (sbyte*)ptr;
while (value.Length - 16 >= idx)
{
Debug.Assert(startingAddress >= ptr && startingAddress <= (ptr + value.Length - 16));

// Load the next 16 bytes.
Vector128<sbyte> sourceValue = Sse2.LoadVector128(startingAddress);

// Check if any of the 16 bytes need to be escaped.
Vector128<sbyte> mask = CreateEscapingMask(sourceValue);

int index = Sse2.MoveMask(mask.AsByte());
// If index == 0, that means none of the 16 bytes needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
// Found at least one byte that needs to be escaped, figure out the index of
// the first one found that needed to be escaped within the 16 bytes.
Debug.Assert(index > 0 && index <= 65_535);
int tzc = BitOperations.TrailingZeroCount(index);
Debug.Assert(tzc >= 0 && tzc <= 16);
idx += tzc;
goto Return;
}
idx += 16;
startingAddress += 16;
}

// Process the remaining characters.
Debug.Assert(value.Length - idx < 16);
}
#endif

for (; idx < value.Length; idx++)
{
Debug.Assert((ptr + idx) <= (ptr + value.Length));
if (NeedsEscaping(*(ptr + idx)))
{
goto Return;
}
}

Return:
return idx;
idx = -1; // all characters allowed

Return:
return idx;
}
}

public static unsafe int NeedsEscaping(ReadOnlySpan<char> value, JavaScriptEncoder encoder)
{
int idx;

// Some implementations of JavascriptEncoder.FindFirstCharacterToEncode may not accept
// null pointers and gaurd against that. Hence, check up-front and fall down to return -1.
if (encoder != null && !value.IsEmpty)
fixed (char* ptr = value)
{
fixed (char* ptr = value)
int idx = 0;

// Some implementations of JavascriptEncoder.FindFirstCharacterToEncode may not accept
// null pointers and gaurd against that. Hence, check up-front and fall down to return -1.
if (encoder != null && !value.IsEmpty)
{
idx = encoder.FindFirstCharacterToEncode(ptr, value.Length);
goto Return;
}
goto Return;
}

for (idx = 0; idx < value.Length; idx++)
{
if (NeedsEscaping(value[idx]))
#if BUILDING_INBOX_LIBRARY
if (Sse2.IsSupported)
{
goto Return;
short* startingAddress = (short*)ptr;
while (value.Length - 8 >= idx)
{
Debug.Assert(startingAddress >= ptr && startingAddress <= (ptr + value.Length - 8));

// Load the next 8 characters.
Vector128<short> sourceValue = Sse2.LoadVector128(startingAddress);

// Check if any of the 8 characters need to be escaped.
Vector128<short> mask = CreateEscapingMask(sourceValue);

int index = Sse2.MoveMask(mask.AsByte());
// If index == 0, that means none of the 8 characters needed to be escaped.
// TrailingZeroCount is relatively expensive, avoid it if possible.
if (index != 0)
{
// Found at least one character that needs to be escaped, figure out the index of
// the first one found that needed to be escaped within the 8 characters.
Debug.Assert(index > 0 && index <= 65_535);
int tzc = BitOperations.TrailingZeroCount(index);
Debug.Assert(tzc % 2 == 0 && tzc >= 0 && tzc <= 16);
idx += tzc >> 1;
goto Return;
}
idx += 8;
startingAddress += 8;
}

// Process the remaining characters.
Debug.Assert(value.Length - idx < 8);
}
#endif

for (; idx < value.Length; idx++)
{
Debug.Assert((ptr + idx) <= (ptr + value.Length));
if (NeedsEscaping(*(ptr + idx)))
{
goto Return;
}
}
}

idx = -1; // all characters allowed
idx = -1; // All characters are allowed.

Return:
return idx;
Return:
return idx;
}
}

public static int GetMaxEscapedLength(int textLength, int firstIndexToEscape)
Expand Down
Loading

0 comments on commit 0089be5

Please sign in to comment.