Skip to content

Commit 75dc197

Browse files
authored
Merge pull request #42 from simdutf/sse_validation
Sse validation
2 parents 6ff99ee + a992e33 commit 75dc197

File tree

3 files changed

+261
-59
lines changed

3 files changed

+261
-59
lines changed

benchmark/Benchmark.cs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,13 @@ public unsafe void SIMDUtf8ValidationRealDataSse()
259259
{
260260
if (allLinesUtf8 != null)
261261
{
262-
RunUtf8ValidationBenchmark(allLinesUtf8, SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse);
262+
RunUtf8ValidationBenchmark(allLinesUtf8, (byte* pInputBuffer, int inputLength) =>
263+
{
264+
int dummyUtf16CodeUnitCountAdjustment, dummyScalarCountAdjustment;
265+
// Call the method with additional out parameters within the lambda.
266+
// You must handle these additional out parameters inside the lambda, as they cannot be passed back through the delegate.
267+
return SimdUnicode.UTF8.GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength, out dummyUtf16CodeUnitCountAdjustment, out dummyScalarCountAdjustment);
268+
});
263269
}
264270
}
265271

src/UTF8.cs

Lines changed: 137 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ namespace SimdUnicode
1010
public static class UTF8
1111
{
1212

13+
1314
// Returns &inputBuffer[inputLength] if the input buffer is valid.
1415
/// <summary>
1516
/// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
@@ -35,11 +36,10 @@ public static class UTF8
3536
{
3637
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength);
3738
}*/
38-
// if (Ssse3.IsSupported)
39-
// {
40-
// return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength);
41-
// }
42-
// return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength);
39+
if (Ssse3.IsSupported)
40+
{
41+
return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength,out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
42+
}
4343

4444
return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength, out Utf16CodeUnitCountAdjustment, out ScalarCodeUnitCountAdjustment);
4545

@@ -471,15 +471,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
471471
return (utfadjust, scalaradjust);
472472
}
473473

474-
public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength)
474+
public unsafe static byte* GetPointerToFirstInvalidByteSse(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
475475
{
476-
477476
int processedLength = 0;
478-
int TempUtf16CodeUnitCountAdjustment = 0;
479-
int TempScalarCountAdjustment = 0;
480-
481477
if (pInputBuffer == null || inputLength <= 0)
482478
{
479+
utf16CodeUnitCountAdjustment = 0;
480+
scalarCountAdjustment = 0;
483481
return pInputBuffer;
484482
}
485483
if (inputLength > 128)
@@ -503,24 +501,24 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
503501

504502
if (processedLength + 16 < inputLength)
505503
{
506-
// We still have work to do!
507504
Vector128<byte> prevInputBlock = Vector128<byte>.Zero;
508505

509506
Vector128<byte> maxValue = Vector128.Create(
510507
255, 255, 255, 255, 255, 255, 255, 255,
511508
255, 255, 255, 255, 255, 0b11110000 - 1, 0b11100000 - 1, 0b11000000 - 1);
512-
Vector128<byte> prevIncomplete = Sse2.SubtractSaturate(prevInputBlock, maxValue);
513-
509+
Vector128<byte> prevIncomplete = Sse3.SubtractSaturate(prevInputBlock, maxValue);
514510

515-
Vector128<byte> shuf1 = Vector128.Create(TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
511+
Vector128<byte> shuf1 = Vector128.Create(
512+
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
516513
TOO_LONG, TOO_LONG, TOO_LONG, TOO_LONG,
517514
TWO_CONTS, TWO_CONTS, TWO_CONTS, TWO_CONTS,
518515
TOO_SHORT | OVERLONG_2,
519516
TOO_SHORT,
520517
TOO_SHORT | OVERLONG_3 | SURROGATE,
521518
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4);
522519

523-
Vector128<byte> shuf2 = Vector128.Create(CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
520+
Vector128<byte> shuf2 = Vector128.Create(
521+
CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4,
524522
CARRY | OVERLONG_2,
525523
CARRY,
526524
CARRY,
@@ -536,7 +534,8 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
536534
CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE,
537535
CARRY | TOO_LARGE | TOO_LARGE_1000,
538536
CARRY | TOO_LARGE | TOO_LARGE_1000);
539-
Vector128<byte> shuf3 = Vector128.Create(TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
537+
Vector128<byte> shuf3 = Vector128.Create(
538+
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
540539
TOO_SHORT, TOO_SHORT, TOO_SHORT, TOO_SHORT,
541540
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4,
542541
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE,
@@ -548,24 +547,71 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
548547
Vector128<byte> fourthByte = Vector128.Create((byte)(0b11110000u - 0x80));
549548
Vector128<byte> v0f = Vector128.Create((byte)0x0F);
550549
Vector128<byte> v80 = Vector128.Create((byte)0x80);
550+
/****
551+
* So we want to count the number of 4-byte sequences,
552+
* the number of 3-byte sequences, and
553+
* the number of 2-byte sequences.
554+
* We can do it indirectly. We know how many bytes in total
555+
* we have (length). Let us assume that the length covers
556+
* only complete sequences (we need to adjust otherwise).
557+
* We have that
558+
* length = 4 * n4 + 3 * n3 + 2 * n2 + n1
559+
* where n1 is the number of 1-byte sequences (ASCII),
560+
* n2 is the number of 2-byte sequences, n3 is the number
561+
* of 3-byte sequences, and n4 is the number of 4-byte sequences.
562+
*
563+
* Let ncon be the number of continuation bytes, then we have
564+
* length = n4 + n3 + n2 + ncon + n1
565+
*
566+
* We can solve for n2 and n3 in terms of the other variables:
567+
* n3 = n1 - 2 * n4 + 2 * ncon - length
568+
* n2 = -2 * n1 + n4 - 3 * ncon + 2 * length
569+
* Thus we only need to count the number of continuation bytes,
570+
* the number of ASCII bytes and the number of 4-byte sequences.
571+
*/
572+
////////////
573+
// The *block* here is what begins at processedLength and ends
574+
// at processedLength/16*16 or when an error occurs.
575+
///////////
576+
int start_point = processedLength;
577+
578+
// The block goes from processedLength to processedLength/16*16.
579+
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
580+
int contbytes = 0; // number of continuation bytes in the block
581+
int n4 = 0; // number of 4-byte sequences that start in this block
551582
for (; processedLength + 16 <= inputLength; processedLength += 16)
552583
{
553584

554-
Vector128<byte> currentBlock = Sse2.LoadVector128(pInputBuffer + processedLength);
555-
556-
int mask = Sse2.MoveMask(currentBlock);
585+
Vector128<byte> currentBlock = Avx.LoadVector128(pInputBuffer + processedLength);
586+
int mask = Sse42.MoveMask(currentBlock);
557587
if (mask == 0)
558588
{
559589
// We have an ASCII block, no need to process it, but
560590
// we need to check if the previous block was incomplete.
561-
if (Sse2.MoveMask(prevIncomplete) != 0)
591+
//
592+
593+
if (!Sse41.TestZ(prevIncomplete, prevIncomplete))
562594
{
563-
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);
595+
int off = processedLength >= 3 ? processedLength - 3 : processedLength;
596+
byte* invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(16 - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
597+
// So the code is correct up to invalidBytePointer
598+
if (invalidBytePointer < pInputBuffer + processedLength)
599+
{
600+
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
601+
}
602+
else
603+
{
604+
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
605+
}
606+
int totalbyteasciierror = processedLength - start_point;
607+
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, totalbyteasciierror);
608+
return invalidBytePointer;
564609
}
565610
prevIncomplete = Vector128<byte>.Zero;
566611
}
567-
else
612+
else // Contains non-ASCII characters, we need to do non-trivial processing
568613
{
614+
// Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
569615
// Contains non-ASCII characters, we need to do non-trivial processing
570616
Vector128<byte> prev1 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 1));
571617
Vector128<byte> byte_1_high = Ssse3.Shuffle(shuf1, Sse2.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
@@ -575,54 +621,93 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
575621
Vector128<byte> prev2 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 2));
576622
Vector128<byte> prev3 = Ssse3.AlignRight(currentBlock, prevInputBlock, (byte)(16 - 3));
577623
prevInputBlock = currentBlock;
624+
578625
Vector128<byte> isThirdByte = Sse2.SubtractSaturate(prev2, thirdByte);
579626
Vector128<byte> isFourthByte = Sse2.SubtractSaturate(prev3, fourthByte);
580627
Vector128<byte> must23 = Sse2.Or(isThirdByte, isFourthByte);
581628
Vector128<byte> must23As80 = Sse2.And(must23, v80);
582629
Vector128<byte> error = Sse2.Xor(must23As80, sc);
583-
if (Sse2.MoveMask(error) != 0)
630+
631+
if (!Sse42.TestZ(error, error))
584632
{
585-
return SimdUnicode.UTF8.RewindAndValidateWithErrors(processedLength, pInputBuffer + processedLength, inputLength - processedLength, ref TempUtf16CodeUnitCountAdjustment, ref TempScalarCountAdjustment);
633+
634+
byte* invalidBytePointer;
635+
if (processedLength == 0)
636+
{
637+
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
638+
}
639+
else
640+
{
641+
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
642+
}
643+
if (invalidBytePointer < pInputBuffer + processedLength)
644+
{
645+
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
646+
}
647+
else
648+
{
649+
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
650+
}
651+
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
652+
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
653+
return invalidBytePointer;
586654
}
587-
prevIncomplete = Sse2.SubtractSaturate(currentBlock, maxValue);
655+
656+
prevIncomplete = Sse3.SubtractSaturate(currentBlock, maxValue);
657+
658+
contbytes += (int)Popcnt.PopCount((uint)Sse42.MoveMask(byte_2_high));
659+
// We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
660+
n4 += (int)Popcnt.PopCount((uint)Sse42.MoveMask(Sse42.SubtractSaturate(currentBlock, fourthByte)));
588661
}
662+
663+
// important: we just update asciibytes if there was no error.
664+
// We count the number of ascii bytes in the block using just some simple arithmetic
665+
// and no expensive operation:
666+
asciibytes += (int)(16 - Popcnt.PopCount((uint)mask));
589667
}
590-
}
591-
}
592-
// We have processed all the blocks using SIMD, we need to process the remaining bytes.
593668

594-
// Process the remaining bytes with the scalar function
595-
if (processedLength < inputLength)
596-
{
597-
// We need to possibly backtrack to the start of the last code point
598-
// worst possible case is 4 bytes, where we need to backtrack 3 bytes
599-
// 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
600-
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
601-
{
602-
processedLength -= 1;
603-
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
669+
670+
// We may still have an error.
671+
if (processedLength < inputLength || !Sse42.TestZ(prevIncomplete, prevIncomplete))
604672
{
605-
processedLength -= 1;
606-
if (processedLength > 0 && (sbyte)pInputBuffer[processedLength] <= -65)
673+
byte* invalidBytePointer;
674+
if (processedLength == 0)
675+
{
676+
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(0, pInputBuffer + processedLength, inputLength - processedLength);
677+
}
678+
else
679+
{
680+
invalidBytePointer = SimdUnicode.UTF8.SimpleRewindAndValidateWithErrors(processedLength - 3, pInputBuffer + processedLength - 3, inputLength - processedLength + 3);
681+
682+
}
683+
if (invalidBytePointer != pInputBuffer + inputLength)
684+
{
685+
if (invalidBytePointer < pInputBuffer + processedLength)
686+
{
687+
removeCounters(invalidBytePointer, pInputBuffer + processedLength, ref asciibytes, ref n4, ref contbytes);
688+
}
689+
else
690+
{
691+
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
692+
}
693+
int total_bytes_processed = (int)(invalidBytePointer - (pInputBuffer + start_point));
694+
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, total_bytes_processed);
695+
return invalidBytePointer;
696+
}
697+
else
607698
{
608-
processedLength -= 1;
699+
addCounters(pInputBuffer + processedLength, invalidBytePointer, ref asciibytes, ref n4, ref contbytes);
609700
}
610701
}
611-
}
612-
int TailScalarCodeUnitCountAdjustment = 0;
613-
int TailUtf16CodeUnitCountAdjustment = 0;
614-
byte* invalidBytePointer = SimdUnicode.UTF8.GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out TailUtf16CodeUnitCountAdjustment, out TailScalarCodeUnitCountAdjustment);
615-
if (invalidBytePointer != pInputBuffer + inputLength)
616-
{
617-
// An invalid byte was found by the scalar function
618-
return invalidBytePointer;
702+
int final_total_bytes_processed = inputLength - start_point;
703+
(utf16CodeUnitCountAdjustment, scalarCountAdjustment) = CalculateN2N3FinalSIMDAdjustments(asciibytes, n4, contbytes, final_total_bytes_processed);
704+
return pInputBuffer + inputLength;
619705
}
620706
}
621-
622-
return pInputBuffer + inputLength;
707+
return GetPointerToFirstInvalidByteScalar(pInputBuffer + processedLength, inputLength - processedLength, out utf16CodeUnitCountAdjustment, out scalarCountAdjustment);
623708
}
624709

625-
710+
//
626711
public unsafe static byte* GetPointerToFirstInvalidByteAvx2(byte* pInputBuffer, int inputLength, out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
627712
{
628713
int processedLength = 0;

0 commit comments

Comments
 (0)