Skip to content

Commit 0949b93

Browse files
committed
Added updating of counts if no error
1 parent 631195e commit 0949b93

File tree

1 file changed

+59
-46
lines changed

1 file changed

+59
-46
lines changed

src/UTF8.cs

Lines changed: 59 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ public static class UTF8
200200
// Note that this function is unsafe, and it is the caller's responsibility
201201
// to ensure that we can read at least 4 bytes before pInputBuffer.
202202
// (Nick Nuon added 7th may) there is an addenum labeled important in the mock PR however I think we can treat unterminated as
203-
public unsafe static (int totalbyteadjustment,int i,int ascii,int n2,int n4) adjustmentFactor(byte* pInputBuffer) {
203+
public unsafe static (int totalbyteadjustment,int backedupByHowMuch,int ascii,int n2,int n4) adjustmentFactor(byte* pInputBuffer) {
204204
// Find the first non-continuation byte, working backward.
205205
int i = 1;
206206
for (; i <= 4; i++)
@@ -223,27 +223,41 @@ public unsafe static (int totalbyteadjustment,int i,int ascii,int n2,int n4) adj
223223
return (4 - i,i,0,0,-1); // We have that i == 1 or i == 2 or i == 3 or i == 4, if i == 1, we are missing three bytes, if i == 2, we are missing two bytes, if i == 3, we are missing one byte.
224224
}
225225

226-
public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int start_point, int processedLength, byte* pInputBuffer, int asciibytes, int n4, int contbytes)
226+
public static (int utfadjust, int scalaradjust) CalculateN2N3FinalSIMDAdjustments(int asciibytes, int n4, int contbytes, int totalbyte)
227+
{
228+
// Calculate n3 based on the provided formula
229+
int n3 = asciibytes - 2 * n4 + 2 * contbytes - totalbyte;
230+
231+
// Calculate n2 based on the provided formula
232+
int n2 = -2 * asciibytes + n4 - 3 * contbytes + 2 * totalbyte;
233+
234+
// Calculate utfadjust by adding them all up
235+
int utfadjust = -2 * n4 - 2 * n3 - n2;
236+
237+
// Calculate scalaradjust based on n4
238+
int scalaradjust = -n4;
239+
240+
// Return the calculated utfadjust and scalaradjust
241+
return (utfadjust, scalaradjust);
242+
}
243+
244+
245+
246+
247+
248+
public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(int start_point, int processedLength, byte* pInputBuffer, int asciibytes, int n4, int n2, int contbytes)
227249
{
228250
// Calculate the total bytes from start_point to processedLength
229251
int totalbyte = processedLength - start_point;
252+
int adjusttotalbyte = 0, backedupByHowMuch = 0, adjustascii = 0, adjustn2 = 0, adjustn4 = 0;
230253

231254
// Adjust the length to include a complete character, if necessary
232255
if (totalbyte > 0)
233256
{
234-
var (temptotalbyte,i ,tempascii, tempn2, tempn4) = adjustmentFactor(pInputBuffer + processedLength);
257+
(adjusttotalbyte, backedupByHowMuch ,adjustascii, adjustn2, adjustn4) = adjustmentFactor(pInputBuffer + processedLength);
235258
}
236259

237-
// Calculate n3 based on provided formula
238-
int n3 = asciibytes - 2 * n4 + 2 * contbytes - totalbyte;
239-
240-
// Calculate n2 based on provided formula
241-
int n2 = -2 * asciibytes + n4 - 4 * contbytes + 2 * totalbyte;
242-
243-
// TODO add them all up
244-
245-
int utfadjust = -2 * n4 - 2* n3 - n2;
246-
int scalaradjust = n4;
260+
var (utfadjust,scalaradjust) = CalculateN2N3FinalSIMDAdjustments( asciibytes + adjustascii, n4 + adjustn4, contbytes + adjustn2, totalbyte + adjusttotalbyte);
247261

248262
// Return the calculated n2 and n3
249263
return (utfadjust, scalaradjust);
@@ -405,7 +419,7 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
405419

406420
public unsafe static byte* GetPointerToFirstInvalidByteAvx2(byte* pInputBuffer, int inputLength,out int utf16CodeUnitCountAdjustment, out int scalarCountAdjustment)
407421
{
408-
// Console.WriteLine("--------------------------Calling function----------------------------------");
422+
Console.WriteLine("--------------------------Calling function----------------------------------");
409423
// Console.WriteLine("Length: " + inputLength);
410424
int processedLength = 0;
411425
int TempUtf16CodeUnitCountAdjustment= 0 ;
@@ -546,7 +560,9 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
546560
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
547561
int contbytes = 0; // number of continuation bytes in the block
548562
int n4 = 0; // number of 4-byte sequences that start in this block
549-
int totalbyte, n3, n2;
563+
// int totalbyte = 0, n3 = 0, n2 = 0;
564+
565+
550566

551567
for (; processedLength + 32 <= inputLength; processedLength += 32)
552568
{
@@ -560,21 +576,18 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
560576
// we need to check if the previous block was incomplete.
561577
if (!Avx2.TestZ(prevIncomplete, prevIncomplete))
562578
{
563-
564-
// TODO/think about : this path iss not explicitly tested
565-
// Console.WriteLine("----All ASCII need rewind");
579+
// TODO? : this path iss not explicitly tested
566580
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment;
567581
scalarCountAdjustment = TempScalarCountAdjustment;
568582

569583
int off = processedLength >= 3 ? processedLength - 3 : processedLength;
570-
// int off = processedLength;
571584
return SimdUnicode.UTF8.RewindAndValidateWithErrors(off, pInputBuffer + off, inputLength - off, ref utf16CodeUnitCountAdjustment,ref scalarCountAdjustment);
572585
}
573586
prevIncomplete = Vector256<byte>.Zero;
574587
}
575588
else // Contains non-ASCII characters, we need to do non-trivial processing
576589
{
577-
// Console.WriteLine("--Found non-ascii:triggering SIMD routine at " + processedLength + "bytes"); //debug
590+
Console.WriteLine("--Found non-ascii:triggering SIMD routine at " + processedLength + "bytes"); //debug
578591
// Use SubtractSaturate to effectively compare if bytes in block are greater than markers.
579592
Vector256<byte> shuffled = Avx2.Permute2x128(prevInputBlock, currentBlock, 0x21);
580593
prevInputBlock = currentBlock;
@@ -594,7 +607,7 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
594607
Vector256<byte> error = Avx2.Xor(must23As80, sc);
595608
if (!Avx2.TestZ(error, error))
596609
{
597-
// Console.WriteLine("-----Error path!!");
610+
Console.WriteLine("-----Error path!!");
598611
TailScalarCodeUnitCountAdjustment =0;
599612
TailUtf16CodeUnitCountAdjustment =0;
600613

@@ -610,11 +623,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
610623

611624
return invalidBytePointer;
612625
}
613-
// Adjustments :TODO:
614-
// TempUtf16CodeUnitCountAdjustment -= (int)fourByteCount * 2;
615-
// TempUtf16CodeUnitCountAdjustment -= (int)twoByteCount;
616-
// TempUtf16CodeUnitCountAdjustment -= (int)threeByteCount *2;
617-
// TempScalarCountAdjustment -= (int)fourByteCount;
618626

619627
// Console.WriteLine("Doublecount(Temp) after SIMD processing:" + TempUtf16CodeUnitCountAdjustment); debug
620628
// Console.WriteLine("Scalarcount after SIMD processing:" + TempScalarCountAdjustment);
@@ -623,7 +631,7 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
623631
if (!Avx2.TestZ(prevIncomplete, prevIncomplete))
624632
{
625633
// We have an unterminated sequence.
626-
// Console.WriteLine("---Unterminated seq--- at " + processedLength + "bytes");
634+
Console.WriteLine("---Unterminated seq--- at " + processedLength + "bytes");
627635
// processedLength -= 3;
628636

629637
// Console.WriteLine("incomplete utf16 count", incompleteUtf16CodeUnitPreventDoubleCounting);
@@ -634,43 +642,48 @@ public unsafe static (int utfadjust, int scalaradjust) calculatefinaladjust(int
634642

635643
var (totalbyteadjustment, i,tempascii, tempn2, tempn4) = adjustmentFactor(pInputBuffer + processedLength + 32);
636644
processedLength -= i;
637-
638-
639-
// for(int k = 0; k < 3; k++)
640-
// {
641-
// // TODO:I do not remember why I put +32 here but the compiler complains if I remeve it
642-
// int candidateByte = pInputBuffer[processedLength + 32 + k];
643-
// // Console.WriteLine("Backing up " + k +" bytes");
644-
// // Console.WriteLine("Byte after backing up:" + Convert.ToString(candidateByte, 2).PadLeft(8, '0'));
645-
646-
// // backedup = 3-k +1;
647-
648-
// if ((candidateByte & 0b11000000) == 0b11000000)
649-
// {
650-
// // Whatever you do, do not delete this
651-
// processedLength += k;
652-
// break;
653-
// }
654-
// }
645+
// totalbyte -= totalbyteadjustment;
646+
asciibytes +=tempascii;
647+
n4 += tempn4;
648+
contbytes +=tempn2;
655649

656650
// // Console.WriteLine("Backed up " + backedup +" bytes");
657651
// // Console.WriteLine("TempUTF16:"+ TempUtf16CodeUnitCountAdjustment);
658652
// // Console.WriteLine("TempScalar:"+ TempScalarCountAdjustment);
659653
// // Console.WriteLine("-----------------");
660654

661655
}
656+
657+
// We use one instruction (MoveMask) to update ncon, plus one arithmetic operation.
658+
contbytes += Avx2.MoveMask(sc);
659+
660+
// We use two instructions (SubtractSaturate and MoveMask) to update n4, with one arithmetic operation.
661+
n4 += Avx2.MoveMask(Avx2.SubtractSaturate(currentBlock, fourthByte));
662662
}
663663
}
664+
// There are 2 possible scenarios here : either
665+
// A) it arrives flush en the border. eg it doesnt need to be processed further
666+
// B) There is some bytes remaining in which case we need to call the scalar functien
667+
// Either way we need to calculate n2,n3 and update the utf16adjust and scalar adjust
668+
int totalbyte = processedLength - start_point;
669+
var (utf16adjust, scalaradjust) = CalculateN2N3FinalSIMDAdjustments( asciibytes, n4, contbytes, totalbyte);
670+
671+
utf16CodeUnitCountAdjustment = utf16adjust;
672+
scalarCountAdjustment = scalaradjust;
664673
}
674+
675+
665676
}
666677
// Console.WriteLine("-Done with SIMD part!"); //debug
667678
// We have processed all the blocks using SIMD, we need to process the remaining bytes.
668679
// Process the remaining bytes with the scalar function
680+
681+
669682
// worst possible case is 4 bytes, where we need to backtrack 3 bytes
670683
// 11110xxxx 10xxxxxx 10xxxxxx 10xxxxxx <== we might be pointing at the last byte
671684
if (processedLength < inputLength)
672685
{
673-
// Console.WriteLine("----Process remaining Scalar @ " + processedLength + "bytes");
686+
Console.WriteLine("----Process remaining Scalar @ " + processedLength + "bytes");
674687
// int overlapCount = 0;
675688
// Console.WriteLine("processed length after backtrack:" + processedLength);
676689
// Console.WriteLine("TempUTF16 before tail remaining check:"+ TempUtf16CodeUnitCountAdjustment);

0 commit comments

Comments
 (0)