Skip to content

Commit

Permalink
Improvements for SpanHelpers.IndexOf (dotnet#64872)
Browse files Browse the repository at this point in the history
  • Loading branch information
EgorBo authored Feb 7, 2022
1 parent c86b338 commit bc9cdd1
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 86 deletions.
96 changes: 60 additions & 36 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ public static int IndexOf(ref byte searchSpace, int searchSpaceLength, ref byte
if (valueTailLength == 0)
return IndexOf(ref searchSpace, value, searchSpaceLength); // for single-byte values use plain IndexOf

int offset = 0;
nint offset = 0;
byte valueHead = value;
int searchSpaceMinusValueTailLength = searchSpaceLength - valueTailLength;
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<byte>.Count)
Expand Down Expand Up @@ -54,7 +54,7 @@ public static int IndexOf(ref byte searchSpace, int searchSpaceLength, ref byte
if (SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + 1),
ref valueTail, (nuint)(uint)valueTailLength)) // The (nuint)-cast is necessary to pick the correct overload
return offset; // The tail matched. Return a successful find.
return (int)offset; // The tail matched. Return a successful find.

remainingSearchSpaceLength--;
offset++;
Expand All @@ -69,48 +69,60 @@ ref Unsafe.Add(ref searchSpace, offset + 1),
// Find the last unique (which is not equal to ch1) byte
// the algorithm is fine if both are equal, just a little bit less efficient
byte ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == value && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector256<byte> ch1 = Vector256.Create(value);
Vector256<byte> ch2 = Vector256.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector256<byte>.Count;

do
{
Debug.Assert(offset >= 0);
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector256<byte>.Count <= searchSpaceLength);

Vector256<byte> cmpCh1 = Vector256.Equals(ch1, Vector256.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector256<byte> cmpCh2 = Vector256.Equals(ch2, Vector256.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)));
Vector256<byte> cmpCh1 = Vector256.Equals(ch1, Vector256.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector256<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return offset + bitPos;
}
mask = BitOperations.ResetLowestSetBit(mask); // Clear the lowest set bit
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector256<byte>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector256<byte>.Count)
offset = searchSpaceMinusValueTailLength - Vector256<byte>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return (int)(offset + bitPos);
}
mask = BitOperations.ResetLowestSetBit(mask); // Clear the lowest set bit
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
Expand All @@ -125,42 +137,54 @@ ref Unsafe.Add(ref searchSpace, offset + bitPos),
Vector128<byte> ch1 = Vector128.Create(value);
Vector128<byte> ch2 = Vector128.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector128<byte>.Count;

do
{
Debug.Assert(offset >= 0);
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector128<byte>.Count <= searchSpaceLength);

Vector128<byte> cmpCh1 = Vector128.Equals(ch1, Vector128.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector128<byte> cmpCh2 = Vector128.Equals(ch2, Vector128.LoadUnsafe(ref searchSpace, (nuint)(offset + ch1ch2Distance)));
Vector128<byte> cmpCh1 = Vector128.Equals(ch1, Vector128.LoadUnsafe(ref searchSpace, (nuint)offset));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector128<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return offset + bitPos;
}
// Clear the lowest set bit
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector128<byte>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector128<byte>.Count)
offset = searchSpaceMinusValueTailLength - Vector128<byte>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
if (valueLength == 2 || // we already matched two bytes
SequenceEqual(
ref Unsafe.Add(ref searchSpace, offset + bitPos),
ref value, (nuint)(uint)valueLength)) // The (nuint)-cast is necessary to pick the correct overload
{
return (int)(offset + bitPos);
}
// Clear the lowest set bit
mask = BitOperations.ResetLowestSetBit(mask);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
}
Expand Down
124 changes: 74 additions & 50 deletions src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Char.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ public static int IndexOf(ref char searchSpace, int searchSpaceLength, ref char
return IndexOf(ref searchSpace, value, searchSpaceLength);
}

int offset = 0;
nint offset = 0;
char valueHead = value;
int searchSpaceMinusValueTailLength = searchSpaceLength - valueTailLength;
if (Vector128.IsHardwareAccelerated && searchSpaceMinusValueTailLength >= Vector128<ushort>.Count)
Expand Down Expand Up @@ -59,7 +59,7 @@ ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + 1)),
ref valueTail,
(nuint)(uint)valueTailLength * 2))
{
return offset; // The tail matched. Return a successful find.
return (int)offset; // The tail matched. Return a successful find.
}

remainingSearchSpaceLength--;
Expand All @@ -75,109 +75,133 @@ ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + 1)),
// Find the last unique (which is not equal to ch1) character
// the algorithm is fine if both are equal, just a little bit less efficient
ushort ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == valueHead && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector256<ushort> ch1 = Vector256.Create((ushort)valueHead);
Vector256<ushort> ch2 = Vector256.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector256<ushort>.Count;

do
{
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector256<ushort>.Count <= searchSpaceLength);

Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, LoadVector256(ref searchSpace, offset));
Vector256<ushort> cmpCh2 = Vector256.Equals(ch2, LoadVector256(ref searchSpace, offset + ch1ch2Distance));
Vector256<ushort> cmpCh1 = Vector256.Equals(ch1, LoadVector256(ref searchSpace, offset));
Vector256<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector256<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return offset + charPos;
}

// Clear the two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector256<ushort>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector256<ushort>.Count)
offset = searchSpaceMinusValueTailLength - Vector256<ushort>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
nint charPos = (nint)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return (int)(offset + charPos);
}

// Clear the two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
else // 128bit vector path (SSE2 or AdvSimd)
{
// Find the last unique (which is not equal to ch1) character
// the algorithm is fine if both are equal, just a little bit less efficient
ushort ch2Val = Unsafe.Add(ref value, valueTailLength);
int ch1ch2Distance = valueTailLength;
nint ch1ch2Distance = valueTailLength;
while (ch2Val == valueHead && ch1ch2Distance > 1)
ch2Val = Unsafe.Add(ref value, --ch1ch2Distance);

Vector128<ushort> ch1 = Vector128.Create((ushort)valueHead);
Vector128<ushort> ch2 = Vector128.Create(ch2Val);

nint searchSpaceMinusValueTailLengthAndVector =
searchSpaceMinusValueTailLength - (nint)Vector128<ushort>.Count;

do
{
// Make sure we don't go out of bounds
Debug.Assert(offset + ch1ch2Distance + Vector128<ushort>.Count <= searchSpaceLength);

Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, LoadVector128(ref searchSpace, offset));
Vector128<ushort> cmpCh2 = Vector128.Equals(ch2, LoadVector128(ref searchSpace, offset + ch1ch2Distance));
Vector128<ushort> cmpCh1 = Vector128.Equals(ch1, LoadVector128(ref searchSpace, offset));
Vector128<byte> cmpAnd = (cmpCh1 & cmpCh2).AsByte();

// Early out: cmpAnd is all zeros
if (cmpAnd != Vector128<byte>.Zero)
{
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return offset + charPos;
}

// Clear two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto CANDIDATE_FOUND;
}

LOOP_FOOTER:
offset += Vector128<ushort>.Count;

if (offset == searchSpaceMinusValueTailLength)
return -1;

// Overlap with the current chunk for trailing elements
if (offset > searchSpaceMinusValueTailLength - Vector128<ushort>.Count)
offset = searchSpaceMinusValueTailLength - Vector128<ushort>.Count;
if (offset > searchSpaceMinusValueTailLengthAndVector)
offset = searchSpaceMinusValueTailLengthAndVector;

continue;

CANDIDATE_FOUND:
uint mask = cmpAnd.ExtractMostSignificantBits();
do
{
int bitPos = BitOperations.TrailingZeroCount(mask);
// div by 2 (shr) because we work with 2-byte chars
int charPos = (int)((uint)bitPos / 2);
if (valueLength == 2 || // we already matched two chars
SequenceEqual(
ref Unsafe.As<char, byte>(ref Unsafe.Add(ref searchSpace, offset + charPos)),
ref Unsafe.As<char, byte>(ref value), (nuint)(uint)valueLength * 2))
{
return (int)(offset + charPos);
}

// Clear two lowest set bits
if (Bmi1.IsSupported)
mask = Bmi1.ResetLowestSetBit(Bmi1.ResetLowestSetBit(mask));
else
mask &= ~(uint)(0b11 << bitPos);
} while (mask != 0);
goto LOOP_FOOTER;

} while (true);
}
}
Expand Down

0 comments on commit bc9cdd1

Please sign in to comment.