Skip to content

Sse validation #42

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Jun 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion benchmark/Benchmark.cs
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,13 @@ public unsafe void SIMDUtf8ValidationRealDataSse()
{
if (allLinesUtf8 != null)
{
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse);
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
{
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
// Call the method with additional out parameters within the lambda.
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
return SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
});
}
}

Expand Down
189 changes: 137 additions & 52 deletions src/UTF8.cs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ namespace SimdUnicode
public static class UTF8
{


// Returns &inputBuffer[inputLength] if the input buffer is valid.
/// <summary>
/// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
Expand All @@ -35,11 +36,10 @@ public static class UTF8
{
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength);
}*/
// if (Ssse3.IsSupported)
// {
// return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength);
// }
// return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength);
if (Ssse3.IsSupported)
{
return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength,out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
}

return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);

Expand Down Expand Up @@ -471,15 +471,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
return (utfadjust, scalaradjust);
}

public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength)
public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{

int processedLength = 0;
int TempUtf16CodeUnitCountAdjustment = 0;
int TempScalarCountAdjustment = 0;

if (pInputBuffer == null || inputLength <= 0)
{
utf16CodeUnitCountAdjustment = 0;
scalarCountAdjustment = 0;
return pInputBuffer;
}
if (inputLength > 128)
Expand All @@ -503,24 +501,24 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust

if (processedLength + 16 < inputLength)
{
// We still have work to do!
Vector128<byte> prevInputBlock = Vector128<byte>.Zero;

Vector128<byte> maxValue = Vector128.Create(
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 0b11110000 - 1, 0b11100000 - 1, 0b11000000 - 1);
Vector128<byte> prevIncomplete = Sse2.SubtractSaturate(prevInputBlock, maxValue);

Vector128<byte> prevIncomplete = Sse3.SubtractSaturate(prevInputBlock, maxValue);

Vector128<byte> shuf1 = Vector128.Create(TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
Vector128<byte> shuf1 = Vector128.Create(
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
TOO_SHORT | OVERLONG_2,
TOO_SHORT,
TOO_SHORT | OVERLONG_3 | SURROGATE,
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4);

Vector128<byte> shuf2 = Vector128.Create(CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
Vector128<byte> shuf2 = Vector128.Create(
CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
CARRY | OVERLONG_2,
CARRY,
CARRY,
Expand All @@ -536,7 +534,8 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
CARRY | TOO_LARGE | TOO_LARGE_1000,
CARRY | TOO_LARGE | TOO_LARGE_1000);
Vector128<byte> shuf3 = Vector128.Create(TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
Vector128<byte> shuf3 = Vector128.Create(
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
Expand All @@ -548,24 +547,71 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
Vector128<byte> fourthByte = Vector128.Create((byte)(0b11110000u - 0x80));
Vector128<byte> v0f = Vector128.Create((byte)0x0F);
Vector128<byte> v80 = Vector128.Create((byte)0x80);
/****
* So we want to count the number of 4-byte sequences,
* the number of 3-byte sequences, and
* the number of 2-byte sequences.
* We can do it indirectly. We know how many bytes in total
* we have (length). Let us assume that the length covers
* only complete sequences (we need to adjust otherwise).
* We have that
* length = 4 * n4 + 3 * n3 + 2 * n2 + n1
* where n1 is the number of 1-byte sequences (ASCII),
* n2 is the number of 2-byte sequences, n3 is the number
* of 3-byte sequences, and n4 is the number of 4-byte sequences.
*
* Let ncon be the number of continuation bytes, then we have
* length = n4 + n3 + n2 + ncon + n1
*
* We can solve for n2 and n3 in terms of the other variables:
* n3 = n1 - 2 * n4 + 2 * ncon - length
* n2 = -2 * n1 + n4 - 3 * ncon + 2 * length
* Thus we only need to count the number of continuation bytes,
* the number of ASCII bytes and the number of 4-byte sequences.
*/
////////////
// The *block* here is what begins at processedLength and ends
// at processedLength/16*16 or when an error occurs.
///////////
int start_point = processedLength;

// The block goes from processedLength to processedLength/16*16.
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
int contbytes = 0; // number of continuation bytes in the block
int n4 = 0; // number of 4-byte sequences that start in this block
for (; processedLength + 16 <= inputLength; processedLength += 16)
{

Vector128<byte> currentBlock = Sse2.LoadVector128(pInputBuffer + processedLength);

int mask = Sse2.MoveMask(currentBlock);
Vector128<byte> currentBlock = Avx.LoadVector128(pInputBuffer + processedLength);
int mask = Sse42.MoveMask(currentBlock);
if (mask == 0)
{
// We have an ASCII block, no need to process it, but
// we need to check if the previous block was incomplete.
if (Sse2.MoveMask(prevIncomplete) != 0)
//

if (!Sse41.TestZ(prevIncomplete, prevIncomplete))
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);
int off = processedLength >= 3 ? processedLength - 3 : processedLength;
byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
// So the code is correct up to invalidBytePointer
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int totalbyteasciierror = processedLength - start_point;
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyteasciierror);
return invalidBytePointer;
}
prevIncomplete = Vector128<byte>.Zero;
}
else
else // Contains non-ASCII characters, we need to do non-trivial processing
{
// Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
// Contains non-ASCII characters, we need to do non-trivial processing
Vector128<byte> prev1 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 1));
Vector128<byte> byte_1_high = Ssse3.Shuffle(shuf1, Sse2.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
Expand All @@ -575,54 +621,93 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
Vector128<byte> prev2 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 2));
Vector128<byte> prev3 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 3));
prevInputBlock = currentBlock;

Vector128<byte> isThirdByte = Sse2.SubtractSaturate(prev2, thirdByte);
Vector128<byte> isFourthByte = Sse2.SubtractSaturate(prev3, fourthByte);
Vector128<byte> must23 = Sse2.Or(isThirdByte, isFourthByte);
Vector128<byte> must23As80 = Sse2.And(must23, v80);
Vector128<byte> error = Sse2.Xor(must23As80, sc);
if (Sse2.MoveMask(error) != 0)

if (!Sse42.TestZ(error, error))
{
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);

byte* invalidBytePointer;
if (processedLength == 0)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
}
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
return invalidBytePointer;
}
prevIncomplete = Sse2.SubtractSaturate(currentBlock, maxValue);

prevIncomplete = Sse3.SubtractSaturate(currentBlock, maxValue);

contbytes += (int)Popcnt.PopCount((uint)Sse42.MoveMask(byte_2_high));
// We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
n4 += (int)Popcnt.PopCount((uint)Sse42.MoveMask(Sse42.SubtractSaturate(currentBlock, fourthByte)));
}

// important: we just update asciibytes if there was no error.
// We count the number of ascii bytes in the block using just some simple arithmetic
// and no expensive operation:
asciibytes += (int)(16 - Popcnt.PopCount((uint)mask));
}
}
}
// We have processed all the blocks using SIMD, we need to process the remaining bytes.

// Process the remaining bytes with the scalar function
if (processedLength < inputLength)
{
// We need to possibly backtrack to the start of the last code point
// worst possible case is 4 bytes, where we need to backtrack 3 bytes
// 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)

// We may still have an error.
if (processedLength < inputLength || !Sse42.TestZ(prevIncomplete, prevIncomplete))
{
processedLength -= 1;
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
byte* invalidBytePointer;
if (processedLength == 0)
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
}
else
{
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);

}
if (invalidBytePointer != pInputBuffer + inputLength)
{
if (invalidBytePointer < pInputBuffer + processedLength)
{
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
}
else
{
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
return invalidBytePointer;
}
else
{
processedLength -= 1;
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
}
}
}
int TailScalarCodeUnitCountAdjustment = 0;
int TailUtf16CodeUnitCountAdjustment = 0;
byte* invalidBytePointer = SimdUnicode.UTF8.GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out TailUtf16CodeUnitCountAdjustment, out TailScalarCodeUnitCountAdjustment);
if (invalidBytePointer != pInputBuffer + inputLength)
{
// An invalid byte was found by the scalar function
return invalidBytePointer;
int final_total_bytes_processed = inputLength - start_point;
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, final_total_bytes_processed);
return pInputBuffer + inputLength;
}
}

return pInputBuffer + inputLength;
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
}


//
public unsafe static byte* GetPointerToFirstInvalidByteAvx2(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
{
int processedLength = 0;
Expand Down
Loading