@@ -813,7 +813,6 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
813
813
if ( processedLength == 0 )
814
814
{
815
815
invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
816
-
817
816
}
818
817
else
819
818
{
@@ -827,9 +826,7 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
827
826
{
828
827
addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
829
828
}
830
-
831
829
int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
832
-
833
830
( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
834
831
return invalidBytePointer ;
835
832
}
@@ -855,16 +852,13 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
855
852
}
856
853
else
857
854
{
858
-
859
855
invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
860
856
861
857
}
862
858
if ( invalidBytePointer != pInputBuffer + inputLength )
863
859
{
864
-
865
860
if ( invalidBytePointer < pInputBuffer + processedLength )
866
861
{
867
-
868
862
removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
869
863
}
870
864
else
@@ -891,16 +885,10 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
891
885
public unsafe static byte * GetPointerToFirstInvalidByteArm64 ( byte * pInputBuffer , int inputLength , out int utf16CodeUnitCountAdjustment , out int scalarCountAdjustment )
892
886
{
893
887
int processedLength = 0 ;
894
- int TempUtf16CodeUnitCountAdjustment = 0 ;
895
- int TempScalarCountAdjustment = 0 ;
896
-
897
- int TailScalarCodeUnitCountAdjustment = 0 ;
898
- int TailUtf16CodeUnitCountAdjustment = 0 ;
899
-
900
888
if ( pInputBuffer == null || inputLength <= 0 )
901
889
{
902
- utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment ;
903
- scalarCountAdjustment = TempScalarCountAdjustment ;
890
+ utf16CodeUnitCountAdjustment = 0 ;
891
+ scalarCountAdjustment = 0 ;
904
892
return pInputBuffer ;
905
893
}
906
894
if ( inputLength > 128 )
@@ -986,14 +974,20 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
986
974
// we need to check if the previous block was incomplete.
987
975
if ( AdvSimd . Arm64 . MaxAcross ( prevIncomplete ) . ToScalar ( ) != 0 )
988
976
{
989
- int totalbyteasciierror = processedLength - start_point ;
990
- var ( utfadjustasciierror , scalaradjustasciierror ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , totalbyteasciierror ) ;
991
-
992
- utf16CodeUnitCountAdjustment = utfadjustasciierror ;
993
- scalarCountAdjustment = scalaradjustasciierror ;
994
-
995
977
int off = processedLength >= 3 ? processedLength - 3 : processedLength ;
996
- return SimdUnicode . UTF8 . RewindAndValidateWithErrors ( off , pInputBuffer + off , inputLength - off , ref utf16CodeUnitCountAdjustment , ref scalarCountAdjustment ) ;
978
+ byte * invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 16 - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
979
+ // The input is valid UTF-8 up to (but not including) invalidBytePointer.
980
+ if ( invalidBytePointer < pInputBuffer + processedLength )
981
+ {
982
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
983
+ }
984
+ else
985
+ {
986
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
987
+ }
988
+ int totalbyteasciierror = processedLength - start_point ;
989
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , totalbyteasciierror ) ;
990
+ return invalidBytePointer ;
997
991
}
998
992
prevIncomplete = Vector128 < byte > . Zero ;
999
993
}
@@ -1019,17 +1013,25 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
1019
1013
// hardware:
1020
1014
if ( AdvSimd . Arm64 . MaxAcross ( Vector128 . AsUInt32 ( error ) ) . ToScalar ( ) != 0 )
1021
1015
{
1022
- int off = processedLength >= 3 ? processedLength - 3 : processedLength ;
1023
- byte * invalidBytePointer = SimdUnicode . UTF8 . RewindAndValidateWithErrors ( off , pInputBuffer + processedLength , inputLength - processedLength , ref TailUtf16CodeUnitCountAdjustment , ref TailScalarCodeUnitCountAdjustment ) ;
1024
- utf16CodeUnitCountAdjustment = TailUtf16CodeUnitCountAdjustment ;
1025
- scalarCountAdjustment = TailScalarCodeUnitCountAdjustment ;
1026
-
1027
- int totalbyteasciierror = processedLength - start_point ;
1028
- var ( utfadjustasciierror , scalaradjustasciierror ) = calculateErrorPathadjust ( start_point , processedLength , pInputBuffer , asciibytes , n4 , contbytes ) ;
1029
-
1030
- utf16CodeUnitCountAdjustment += utfadjustasciierror ;
1031
- scalarCountAdjustment += scalaradjustasciierror ;
1032
-
1016
+ byte * invalidBytePointer ;
1017
+ if ( processedLength == 0 )
1018
+ {
1019
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
1020
+ }
1021
+ else
1022
+ {
1023
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
1024
+ }
1025
+ if ( invalidBytePointer < pInputBuffer + processedLength )
1026
+ {
1027
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
1028
+ }
1029
+ else
1030
+ {
1031
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
1032
+ }
1033
+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
1034
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
1033
1035
return invalidBytePointer ;
1034
1036
}
1035
1037
prevIncomplete = AdvSimd . SubtractSaturate ( currentBlock , maxValue ) ;
@@ -1041,34 +1043,44 @@ private unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust
1041
1043
asciibytes -= ( sbyte ) AdvSimd . Arm64 . AddAcross ( AdvSimd . CompareLessThan ( currentBlock , v80 ) ) . ToScalar ( ) ;
1042
1044
}
1043
1045
1044
- int totalbyte = processedLength - start_point ;
1045
- var ( utf16adjust , scalaradjust ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , totalbyte ) ;
1046
-
1047
- TempUtf16CodeUnitCountAdjustment = utf16adjust ;
1048
- TempScalarCountAdjustment = scalaradjust ;
1049
-
1050
- }
1051
- }
1052
- // We have processed all the blocks using SIMD, we need to process the remaining bytes.
1053
- // Process the remaining bytes with the scalar function
1054
- // worst possible case is 4 bytes, where we need to backtrack 3 bytes
1055
- // 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
1056
- if ( processedLength < inputLength )
1057
- {
1058
-
1059
- byte * invalidBytePointer = SimdUnicode . UTF8 . RewindAndValidateWithErrors ( processedLength , pInputBuffer + processedLength , inputLength - processedLength , ref TailUtf16CodeUnitCountAdjustment , ref TailScalarCodeUnitCountAdjustment ) ;
1060
- if ( invalidBytePointer != pInputBuffer + inputLength )
1061
- {
1062
- utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment ;
1063
- scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment ;
1046
+ // We may still have an error.
1047
+ if ( processedLength < inputLength || ! Avx2 . TestZ ( prevIncomplete , prevIncomplete ) )
1048
+ {
1049
+ byte * invalidBytePointer ;
1050
+ if ( processedLength == 0 )
1051
+ {
1052
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( 0 , pInputBuffer + processedLength , inputLength - processedLength ) ;
1053
+ }
1054
+ else
1055
+ {
1056
+ invalidBytePointer = SimdUnicode . UTF8 . SimpleRewindAndValidateWithErrors ( processedLength - 3 , pInputBuffer + processedLength - 3 , inputLength - processedLength + 3 ) ;
1064
1057
1065
- // An invalid byte was found by the scalar function
1066
- return invalidBytePointer ;
1058
+ }
1059
+ if ( invalidBytePointer != pInputBuffer + inputLength )
1060
+ {
1061
+ if ( invalidBytePointer < pInputBuffer + processedLength )
1062
+ {
1063
+ removeCounters ( invalidBytePointer , pInputBuffer + processedLength , ref asciibytes , ref n4 , ref contbytes ) ;
1064
+ }
1065
+ else
1066
+ {
1067
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
1068
+ }
1069
+ int total_bytes_processed = ( int ) ( invalidBytePointer - ( pInputBuffer + start_point ) ) ;
1070
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , total_bytes_processed ) ;
1071
+ return invalidBytePointer ;
1072
+ }
1073
+ else
1074
+ {
1075
+ addCounters ( pInputBuffer + processedLength , invalidBytePointer , ref asciibytes , ref n4 , ref contbytes ) ;
1076
+ }
1077
+ }
1078
+ int final_total_bytes_processed = inputLength - start_point ;
1079
+ ( utf16CodeUnitCountAdjustment , scalarCountAdjustment ) = CalculateN2N3FinalSIMDAdjustments ( asciibytes , n4 , contbytes , final_total_bytes_processed ) ;
1080
+ return pInputBuffer + inputLength ;
1067
1081
}
1068
1082
}
1069
- utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment ;
1070
- scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment ;
1071
- return pInputBuffer + inputLength ;
1083
+ return GetPointerToFirstInvalidByteScalar ( pInputBuffer + processedLength , inputLength - processedLength , out utf16CodeUnitCountAdjustment , out scalarCountAdjustment ) ;
1072
1084
}
1073
1085
1074
1086
private static unsafe void removeCounters ( byte * start , byte * end , ref int asciibytes , ref int n4 , ref int contbytes )
0 commit comments