Skip to content

Commit c0f1a09

Browse files
committed
fix: correct performance problem with the ARM function; it was caused by Vector128.Shuffle (DO NOT USE)
1 parent bc1272b commit c0f1a09

File tree

3 files changed

+7
-10
lines changed

3 files changed

+7
-10
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ You can print the content of a vector register like so:
115115
## Performance tips
116116

117117
- Be careful: `Vector128.Shuffle` is not the same as `Ssse3.Shuffle`, nor is `Vector128.Shuffle` the same as `Avx2.Shuffle`. Prefer the platform-specific intrinsics (`Ssse3.Shuffle` and `Avx2.Shuffle`).
118+
- Similarly, `Vector128.Shuffle` is not the same as `AdvSimd.Arm64.VectorTableLookup`; use the latter.
118119

119120
## More reading
120121

benchmark/Benchmark.cs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,6 @@ public unsafe void Utf8ValidationRealDataScalar()
210210
}
211211
}
212212

213-
214213
[Benchmark]
215214
[BenchmarkCategory("arm64")]
216215
public unsafe void SIMDUtf8ValidationRealDataArm64()

src/UTF8.cs

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -790,7 +790,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
790790
int asciibytes = 0; // number of ascii bytes in the block (could also be called n1)
791791
int contbytes = 0; // number of continuation bytes in the block
792792
int n4 = 0; // number of 4-byte sequences that start in this block
793-
794793
for (; processedLength + 16 <= inputLength; processedLength += 16)
795794
{
796795

@@ -817,9 +816,10 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
817816
{
818817
// Contains non-ASCII characters, we need to do non-trivial processing
819818
Vector128<byte> prev1 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 1));
820-
Vector128<byte> byte_1_high = Vector128.Shuffle(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
821-
Vector128<byte> byte_1_low = Vector128.Shuffle(shuf2, (prev1 & v0f));
822-
Vector128<byte> byte_2_high = Vector128.Shuffle(shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f);
819+
// Use AdvSimd.Arm64.VectorTableLookup rather than Vector128.Shuffle: Vector128.Shuffle caused a performance problem in this function.
820+
Vector128<byte> byte_1_high = AdvSimd.Arm64.VectorTableLookup(shuf1, AdvSimd.ShiftRightLogical(prev1.AsUInt16(), 4).AsByte() & v0f);
821+
Vector128<byte> byte_1_low = AdvSimd.Arm64.VectorTableLookup (shuf2, (prev1 & v0f));
822+
Vector128<byte> byte_2_high = AdvSimd.Arm64.VectorTableLookup (shuf3, AdvSimd.ShiftRightLogical(currentBlock.AsUInt16(), 4).AsByte() & v0f);
823823
Vector128<byte> sc = AdvSimd.And(AdvSimd.And(byte_1_high, byte_1_low), byte_2_high);
824824
Vector128<byte> prev2 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 2));
825825
Vector128<byte> prev3 = AdvSimd.ExtractVector128(prevInputBlock, currentBlock, (byte)(16 - 3));
@@ -849,13 +849,11 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
849849
}
850850
prevIncomplete = AdvSimd.SubtractSaturate(currentBlock, maxValue);
851851
Vector128<sbyte> largestcont = Vector128.Create((sbyte)-65); // -65 => 0b10111111
852-
contbytes += 16 - AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThan(Vector128.AsSByte(currentBlock), largestcont)).ToScalar();
852+
contbytes += -AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThanOrEqual(Vector128.AsSByte(currentBlock), largestcont)).ToScalar();
853853
Vector128<byte> fourthByteMinusOne = Vector128.Create((byte)(0b11110000u - 1));
854854
n4 += (int)(AdvSimd.Arm64.AddAcross(AdvSimd.SubtractSaturate(currentBlock, fourthByteMinusOne)).ToScalar());
855855
}
856-
857-
asciibytes -= (int)AdvSimd.Arm64.AddAcross(AdvSimd.CompareGreaterThanOrEqual(currentBlock, v80)).ToScalar();
858-
856+
asciibytes -= (sbyte)AdvSimd.Arm64.AddAcross(AdvSimd.CompareLessThan(currentBlock, v80)).ToScalar();
859857
}
860858

861859
int totalbyte = processedLength - start_point;
@@ -886,7 +884,6 @@ public unsafe static (int utfadjust, int scalaradjust) calculateErrorPathadjust(
886884
}
887885
utf16CodeUnitCountAdjustment = TempUtf16CodeUnitCountAdjustment + TailUtf16CodeUnitCountAdjustment;
888886
scalarCountAdjustment = TempScalarCountAdjustment + TailScalarCodeUnitCountAdjustment;
889-
890887
return pInputBuffer + inputLength;
891888
}
892889
public unsafe static byte* GetPointerToFirstInvalidByte(byte* pInputBuffer, int inputLength, out int Utf16CodeUnitCountAdjustment, out int ScalarCodeUnitCountAdjustment)

0 commit comments

Comments
 (0)