@@ -10,6 +10,7 @@ namespace SimdUnicode
10
10
public static class UTF8
11
11
{
12
12
13
+
13
14
// Returns &inputBuffer[inputLength] if the input buffer is valid.
14
15
/// <summary>
15
16
/// Given an input buffer <paramref name="pInputBuffer"/> of byte length <paramref name="inputLength"/>,
@@ -35,11 +36,10 @@ public static class UTF8
35
36
{
36
37
return GetPointerToFirstInvalidByteAvx512(pInputBuffer, inputLength);
37
38
}*/
38
- // if (Ssse3.IsSupported)
39
- // {
40
- // return GetPointerToFirstInvalidByteSse(pInputBuffer, inputLength);
41
- // }
42
- // return GetPointerToFirstInvalidByteScalar(pInputBuffer, inputLength);
39
+ if ( Ssse3 . IsSupported )
40
+ {
41
+ return GetPointerToFirstInvalidByteSse ( pInputBuffer , inputLength , out Utf16CodeUnitCountAdjustment , out ScalarCodeUnitCountAdjustment ) ;
42
+ }
43
43
44
44
return GetPointerToFirstInvalidByteScalar ( pInputBuffer , inputLength , out Utf16CodeUnitCountAdjustment , out ScalarCodeUnitCountAdjustment ) ;
45
45
@@ -471,15 +471,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
471
471
return ( utfadjust , scalaradjust ) ;
472
472
}
473
473
474
- public unsafe static byte * GetPointerToFirstInvalidByteSse ( byte * pInputBuffer , int inputLength )
474
+ public unsafe static byte * GetPointerToFirstInvalidByteSse ( byte * pInputBuffer , int inputLength , out int utf16CodeUnitCountAdjustment , out int scalarCountAdjustment )
475
475
{
476
-
477
476
int processedLength = 0 ;
478
- int TempUtf16CodeUnitCountAdjustment = 0 ;
479
- int TempScalarCountAdjustment = 0 ;
480
-
481
477
if ( pInputBuffer == null || inputLength <= 0 )
482
478
{
479
+ utf16CodeUnitCountAdjustment = 0 ;
480
+ scalarCountAdjustment = 0 ;
483
481
return pInputBuffer ;
484
482
}
485
483
if ( inputLength > 128 )
@@ -503,24 +501,24 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
503
501
504
502
if ( processedLength + 16 < inputLength )
505
503
{
506
- // We still have work to do!
507
504
Vector128 < byte > prevInputBlock = Vector128 < byte > . Zero ;
508
505
509
506
Vector128 < byte > maxValue = Vector128 . Create (
510
507
255 , 255 , 255 , 255 , 255 , 255 , 255 , 255 ,
511
508
255 , 255 , 255 , 255 , 255 , 0b11110000 - 1 , 0b11100000 - 1 , 0b11000000 - 1 ) ;
512
- Vector128 < byte > prevIncomplete = Sse2 . SubtractSaturate ( prevInputBlock , maxValue ) ;
513
-
509
+ Vector128 < byte > prevIncomplete = Sse3 . SubtractSaturate ( prevInputBlock , maxValue ) ;
514
510
515
- Vector128 < byte > shuf1 = Vector128 . Create ( TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
511
+ Vector128 < byte > shuf1 = Vector128 . Create (
512
+ TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
516
513
TOO_LONG , TOO_LONG , TOO_LONG , TOO_LONG ,
517
514
TWO_CONTS , TWO_CONTS , TWO_CONTS , TWO_CONTS ,
518
515
TOO_SHORT | OVERLONG_2 ,
519
516
TOO_SHORT ,
520
517
TOO_SHORT | OVERLONG_3 | SURROGATE ,
521
518
TOO_SHORT | TOO_LARGE | TOO_LARGE_1000 | OVERLONG_4 ) ;
522
519
523
- Vector128 < byte > shuf2 = Vector128 . Create ( CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4 ,
520
+ Vector128 < byte > shuf2 = Vector128 . Create (
521
+ CARRY | OVERLONG_3 | OVERLONG_2 | OVERLONG_4 ,
524
522
CARRY | OVERLONG_2 ,
525
523
CARRY ,
526
524
CARRY ,
@@ -536,7 +534,8 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
536
534
CARRY | TOO_LARGE | TOO_LARGE_1000 | SURROGATE ,
537
535
CARRY | TOO_LARGE | TOO_LARGE_1000 ,
538
536
CARRY | TOO_LARGE | TOO_LARGE_1000 ) ;
539
- Vector128 < byte > shuf3 = Vector128 . Create ( TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
537
+ Vector128 < byte > shuf3 = Vector128 . Create (
538
+ TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
540
539
TOO_SHORT , TOO_SHORT , TOO_SHORT , TOO_SHORT ,
541
540
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE_1000 | OVERLONG_4 ,
542
541
TOO_LONG | OVERLONG_2 | TWO_CONTS | OVERLONG_3 | TOO_LARGE ,
@@ -548,24 +547,71 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
548
547
Vector128 < byte > fourthByte = Vector128 . Create ( ( byte ) ( 0b11110000u - 0x80 ) ) ;
549
548
Vector128 < byte > v0f = Vector128 . Create ( ( byte ) 0x0F ) ;
550
549
Vector128 < byte > v80 = Vector128 . Create ( ( byte ) 0x80 ) ;
550
+ /****
551
+ * So we want to count the number of 4-byte sequences,
552
+ * the number of 4-byte sequences, 3-byte sequences, and
553
+ * the number of 2-byte sequences.
554
+ * We can do it indirectly. We know how many bytes in total
555
+ * we have (length). Let us assume that the length covers
556
+ * only complete sequences (we need to adjust otherwise).
557
+ * We have that
558
+ * length = 4 * n4 + 3 * n3 + 2 * n2 + n1
559
+ * where n1 is the number of 1-byte sequences (ASCII),
560
+ * n2 is the number of 2-byte sequences, n3 is the number
561
+ * of 3-byte sequences, and n4 is the number of 4-byte sequences.
562
+ *
563
+ * Let ncon be the number of continuation bytes, then we have
564
+ * length = n4 + n3 + n2 + ncon + n1
565
+ *
566
+ * We can solve for n2 and n3 in terms of the other variables:
567
+ * n3 = n1 - 2 * n4 + 2 * ncon - length
568
+ * n2 = -2 * n1 + n4 - 4 * ncon + 2 * length
569
+ * Thus we only need to count the number of continuation bytes,
570
+ * the number of ASCII bytes and the number of 4-byte sequences.
571
+ */
572
+ ////////////
573
+ // The *block* here is what begins at processedLength and ends
574
+ // at processedLength/16*16 or when an error occurs.
575
+ ///////////
576
+ int start_point = processedLength ;
577
+
578
+ // The block goes from processedLength to processedLength/16*16.
579
+ int asciibytes = 0 ; // number of ascii bytes in the block (could also be called n1)
580
+ int contbytes = 0 ; // number of continuation bytes in the block
581
+ int n4 = 0 ; // number of 4-byte sequences that start in this block
551
582
for ( ; processedLength + 16 <= inputLength ; processedLength += 16 )
552
583
{
553
584
554
- Vector128 < byte > currentBlock = Sse2 . LoadVector128 ( pInputBuffer + processedLength ) ;
555
-
556
- int mask = Sse2 . MoveMask ( currentBlock ) ;
585
+ Vector128 < byte > currentBlock = Avx . LoadVector128 ( pInputBuffer + processedLength ) ;
586
+ int mask = Sse42 . MoveMask ( currentBlock ) ;
557
587
if ( mask == 0 )
558
588
{
559
589
// We have an ASCII block, no need to process it, but
560
590
// we need to check if the previous block was incomplete.
561
- if ( Sse2 . MoveMask ( prevIncomplete ) != 0 )
591
+ //
592
+
593
+ if ( ! Sse41 . TestZ ( prevIncomplete , prevIncomplete ) )
562
594
{
563
- return SimdUnicode . UTF8 . RewindAndValidateWithErrors ( processedLength , pInputBuffer + processedLength , inputLength - processedLength , ref TempUtf16CodeUnitCountAdjustment , ref TempScalarCountAdjustment ) ;
595
+ int off = processedLength >= 3 ? processedLength - 3 : processedLength ;
596
+ byte * invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 16 - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
597
+ // So the code is correct up to invalidBytePointer
598
+ if ( invalidBytePointer < pInputBuffer + processedLength )
599
+ {
600
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
601
+ }
602
+ else
603
+ {
604
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
605
+ }
606
+ int totalbyteasciierror = processedLength - start_point ;
607
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , totalbyteasciierror ) ;
608
+ return invalidBytePointer ;
564
609
}
565
610
prevIncomplete = Vector128 < byte > . Zero ;
566
611
}
567
- else
612
+ else // Contains non-ASCII characters, we need to do non-trivial processing
568
613
{
614
+ // Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
569
615
// Contains non-ASCII characters, we need to do non-trivial processing
570
616
Vector128 < byte > prev1 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 1 ) ) ;
571
617
Vector128 < byte > byte_1_high = Ssse3 . Shuffle ( shuf1 , Sse2 . ShiftRightLogical ( prev1 . AsUInt16 ( ) , 4 ) . AsByte ( ) & v0f ) ;
@@ -575,54 +621,93 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
575
621
Vector128 < byte > prev2 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 2 ) ) ;
576
622
Vector128 < byte > prev3 = Ssse3 . AlignRight ( currentBlock , prevInputBlock , ( byte ) ( 16 - 3 ) ) ;
577
623
prevInputBlock = currentBlock ;
624
+
578
625
Vector128 < byte > isThirdByte = Sse2 . SubtractSaturate ( prev2 , thirdByte ) ;
579
626
Vector128 < byte > isFourthByte = Sse2 . SubtractSaturate ( prev3 , fourthByte ) ;
580
627
Vector128 < byte > must23 = Sse2 . Or ( isThirdByte , isFourthByte ) ;
581
628
Vector128 < byte > must23As80 = Sse2 . And ( must23 , v80 ) ;
582
629
Vector128 < byte > error = Sse2 . Xor ( must23As80 , sc ) ;
583
- if ( Sse2 . MoveMask ( error ) != 0 )
630
+
631
+ if ( ! Sse42 . TestZ ( error , error ) )
584
632
{
585
- return SimdUnicode . UTF8 . RewindAndValidateWithErrors ( processedLength , pInputBuffer + processedLength , inputLength - processedLength , ref TempUtf16CodeUnitCountAdjustment , ref TempScalarCountAdjustment ) ;
633
+
634
+ byte * invalidBytePointer ;
635
+ if ( processedLength == 0 )
636
+ {
637
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
638
+ }
639
+ else
640
+ {
641
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
642
+ }
643
+ if ( invalidBytePointer < pInputBuffer + processedLength )
644
+ {
645
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
646
+ }
647
+ else
648
+ {
649
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
650
+ }
651
+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
652
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
653
+ return invalidBytePointer ;
586
654
}
587
- prevIncomplete = Sse2 . SubtractSaturate ( currentBlock , maxValue ) ;
655
+
656
+ prevIncomplete = Sse3 . SubtractSaturate ( currentBlock , maxValue ) ;
657
+
658
+ contbytes += ( int ) Popcnt . PopCount ( ( uint ) Sse42 . MoveMask ( byte_2_high ) ) ;
659
+ // We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
660
+ n4 += ( int ) Popcnt . PopCount ( ( uint ) Sse42 . MoveMask ( Sse42 . SubtractSaturate ( currentBlock , fourthByte ) ) ) ;
588
661
}
662
+
663
+ // important: we just update asciibytes if there was no error.
664
+ // We count the number of ascii bytes in the block using just some simple arithmetic
665
+ // and no expensive operation:
666
+ asciibytes += ( int ) ( 16 - Popcnt . PopCount ( ( uint ) mask ) ) ;
589
667
}
590
- }
591
- }
592
- // We have processed all the blocks using SIMD, we need to process the remaining bytes.
593
668
594
- // Process the remaining bytes with the scalar function
595
- if ( processedLength < inputLength )
596
- {
597
- // We need to possibly backtrack to the start of the last code point
598
- // worst possible case is 4 bytes, where we need to backtrack 3 bytes
599
- // 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
600
- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
601
- {
602
- processedLength -= 1 ;
603
- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
669
+
670
+ // We may still have an error.
671
+ if ( processedLength < inputLength || ! Sse42 . TestZ ( prevIncomplete , prevIncomplete ) )
604
672
{
605
- processedLength -= 1 ;
606
- if ( processedLength > 0 && ( sbyte ) pInputBuffer [ processedLength ] <= - 65 )
673
+ byte * invalidBytePointer ;
674
+ if ( processedLength == 0 )
675
+ {
676
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
677
+ }
678
+ else
679
+ {
680
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
681
+
682
+ }
683
+ if ( invalidBytePointer != pInputBuffer + inputLength )
684
+ {
685
+ if ( invalidBytePointer < pInputBuffer + processedLength )
686
+ {
687
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
688
+ }
689
+ else
690
+ {
691
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
692
+ }
693
+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
694
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
695
+ return invalidBytePointer ;
696
+ }
697
+ else
607
698
{
608
- processedLength -= 1 ;
699
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
609
700
}
610
701
}
611
- }
612
- int TailScalarCodeUnitCountAdjustment = 0 ;
613
- int TailUtf16CodeUnitCountAdjustment = 0 ;
614
- byte * invalidBytePointer = SimdUnicode . UTF8 . GetPointerToFirstInvalidByteScalar ( pInputBuffer + processedLength , inputLength - processedLength , out TailUtf16CodeUnitCountAdjustment , out TailScalarCodeUnitCountAdjustment ) ;
615
- if ( invalidBytePointer != pInputBuffer + inputLength )
616
- {
617
- // An invalid byte was found by the scalar function
618
- return invalidBytePointer ;
702
+ int final_total_bytes_processed = inputLength - start_point ;
703
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , final_total_bytes_processed ) ;
704
+ return pInputBuffer + inputLength ;
619
705
}
620
706
}
621
-
622
- return pInputBuffer + inputLength ;
707
+ return GetPointerToFirstInvalidByteScalar ( pInputBuffer + processedLength , inputLength - processedLength , out utf16CodeUnitCountAdjustment , out scalarCountAdjustment ) ;
623
708
}
624
709
625
-
710
+ //
626
711
public unsafe static byte * GetPointerToFirstInvalidByteAvx2 ( byte * pInputBuffer , int inputLength , out int utf16CodeUnitCountAdjustment , out int scalarCountAdjustment )
627
712
{
628
713
int processedLength = 0 ;
0 commit comments