From d37e0b67bc201ed4f00b1167366575757bc5502f Mon Sep 17 00:00:00 2001 From: Ben Adams Date: Wed, 15 Jan 2025 17:55:20 +0000 Subject: [PATCH] Allow input params as output params --- src/Nethermind.Int256.Benchmark/Numbers.cs | 2 +- src/Nethermind.Int256.Test/Convertibles.cs | 4 +- src/Nethermind.Int256.Test/Int256Tests.cs | 4 +- src/Nethermind.Int256.Test/TestNumbers.cs | 2 +- src/Nethermind.Int256.Test/UInt256Tests.cs | 248 ++++++++++++++++++--- src/Nethermind.Int256.Test/UnaryOps.cs | 2 +- src/Nethermind.Int256/Int256.cs | 2 +- src/Nethermind.Int256/UInt256.cs | 45 ++-- 8 files changed, 249 insertions(+), 60 deletions(-) diff --git a/src/Nethermind.Int256.Benchmark/Numbers.cs b/src/Nethermind.Int256.Benchmark/Numbers.cs index e610a3f..356d820 100644 --- a/src/Nethermind.Int256.Benchmark/Numbers.cs +++ b/src/Nethermind.Int256.Benchmark/Numbers.cs @@ -12,7 +12,7 @@ public static class Numbers public static readonly BigInteger TwoTo256 = TwoTo128 * TwoTo128; public static readonly BigInteger UInt256Max = TwoTo256 - 1; - public static readonly BigInteger Int256Max = ( BigInteger.One << 255 ) - 1; + public static readonly BigInteger Int256Max = (BigInteger.One << 255) - 1; public static readonly BigInteger Int256Min = -Int256Max; } } diff --git a/src/Nethermind.Int256.Test/Convertibles.cs b/src/Nethermind.Int256.Test/Convertibles.cs index 18ee10c..b0ece29 100644 --- a/src/Nethermind.Int256.Test/Convertibles.cs +++ b/src/Nethermind.Int256.Test/Convertibles.cs @@ -76,10 +76,10 @@ public static (Type type, BigInteger? min, BigInteger? max)[] ConvertibleTypes = (typeof(decimal), (BigInteger?)decimal.MinValue, (BigInteger?)decimal.MaxValue), (typeof(BigInteger), null, null) }; - + public static IEnumerable TestCases => GenerateTestCases(Numbers, BigInteger.Zero); public static IEnumerable SignedTestCases => GenerateTestCases(SignedNumbers); - + private static IEnumerable GenerateTestCases(IEnumerable<(object, string)> numbers, BigInteger? minValue = null) { Type ExpectedException(BigInteger value, BigInteger? min, BigInteger? max) => diff --git a/src/Nethermind.Int256.Test/Int256Tests.cs b/src/Nethermind.Int256.Test/Int256Tests.cs index 3ccbed9..2c8ec52 100644 --- a/src/Nethermind.Int256.Test/Int256Tests.cs +++ b/src/Nethermind.Int256.Test/Int256Tests.cs @@ -63,7 +63,7 @@ string Expected(string valueString) { if (valueString.Contains("Infinity")) { - return valueString.StartsWith('-') ? "-∞" : "∞" ; + return valueString.StartsWith('-') ? "-∞" : "∞"; } string expected = valueString.Replace(",", ""); return type == typeof(float) ? expected[..Math.Min(6, expected.Length)] : type == typeof(double) ? expected[..Math.Min(14, expected.Length)] : expected; @@ -77,7 +77,7 @@ string Expected(string valueString) string convertedValue = Expected(((IFormattable)System.Convert.ChangeType(item, type)).ToString("0.#", null)); convertedValue.Should().BeEquivalentTo(expected); } - catch (Exception e) when(e.GetType() == expectedException) { } + catch (Exception e) when (e.GetType() == expectedException) { } } } } diff --git a/src/Nethermind.Int256.Test/TestNumbers.cs b/src/Nethermind.Int256.Test/TestNumbers.cs index ba6f84f..a50a1a3 100644 --- a/src/Nethermind.Int256.Test/TestNumbers.cs +++ b/src/Nethermind.Int256.Test/TestNumbers.cs @@ -12,7 +12,7 @@ public static class TestNumbers public static readonly BigInteger TwoTo256 = TwoTo128 * TwoTo128; public static readonly BigInteger UInt256Max = TwoTo256 - 1; - public static readonly BigInteger Int256Max = (BigInteger.One << 255)-1; + public static readonly BigInteger Int256Max = (BigInteger.One << 255) - 1; public static readonly BigInteger Int256Min = -Int256Max; } } diff --git a/src/Nethermind.Int256.Test/UInt256Tests.cs b/src/Nethermind.Int256.Test/UInt256Tests.cs index 19577d9..72466b6 100644 --- a/src/Nethermind.Int256.Test/UInt256Tests.cs +++ b/src/Nethermind.Int256.Test/UInt256Tests.cs @@ -42,7 +42,6 @@ public virtual void Add((BigInteger A, BigInteger B) test) a = convert(test.A); - // Test reusing input as output a.Add(b, out b); b.Convert(out resUInt256); resUInt256.Should().Be(resBigInt); @@ -53,6 +52,8 @@ public virtual void AddOverflow((BigInteger A, BigInteger B) test) { BigInteger resUInt256; BigInteger resBigInt = test.A + test.B; + bool expectedOverflow = test.A + test.B > (BigInteger)UInt256.MaxValue; + resBigInt %= (BigInteger.One << 256); resBigInt = postprocess(resBigInt); T uint256a = convert(test.A); @@ -62,15 +63,22 @@ public virtual void AddOverflow((BigInteger A, BigInteger B) test) res.Convert(out resUInt256); resUInt256.Should().Be(resBigInt); + overflow.Should().Be(expectedOverflow); - if (test.A + test.B <= (BigInteger)UInt256.MaxValue) - { - overflow.Should().Be(false); - } - else - { - overflow.Should().Be(true); - } + // Test reusing input as output + overflow = T.AddOverflow(uint256a, uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + overflow.Should().Be(expectedOverflow); + + uint256a = convert(test.A); + + overflow = T.AddOverflow(uint256a, uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + overflow.Should().Be(expectedOverflow); } [TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))] @@ -91,6 +99,26 @@ public virtual void AddMod((BigInteger A, BigInteger B, BigInteger M) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.AddMod(uint256b, uint256m, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + + uint256a.AddMod(uint256b, uint256m, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256b = convert(test.B); + + uint256a.AddMod(uint256b, uint256m, out uint256m); + uint256m.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] @@ -111,6 +139,19 @@ public virtual void Subtract((BigInteger A, BigInteger B) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.Subtract(uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + + uint256a.Subtract(uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))] @@ -143,45 +184,78 @@ protected void SubtractModCore((BigInteger A, BigInteger B, BigInteger M) test, res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.SubtractMod(uint256b, uint256m, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + + uint256a.SubtractMod(uint256b, uint256m, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256b = convert(test.B); + + uint256a.SubtractMod(uint256b, uint256m, out uint256m); + uint256m.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] public virtual void SubtractOverflow((BigInteger A, BigInteger B) test) { - BigInteger resUInt256; BigInteger resBigInt = test.A - test.B; resBigInt %= BigInteger.One << 256; resBigInt = postprocess(resBigInt); T uint256a = convert(test.A); T uint256b = convert(test.B); - if (test.A >= test.B) - { - if (uint256a is UInt256 a && uint256b is UInt256 b) - { - UInt256 res = a - b; - res.Convert(out resUInt256); - } - else - { - uint256a.Subtract(uint256b, out T res); - res.Convert(out resUInt256); - } - resUInt256.Should().Be(resBigInt); - } - else + SubtractTest(in uint256a, in uint256b, out T res); + + // Test reusing input as output + SubtractTest(in uint256a, in uint256b, out uint256a); + + uint256a = convert(test.A); + + SubtractTest(in uint256a, in uint256b, out uint256b); + + void SubtractTest(in T uint256a, in T uint256b, out T res) { - if (uint256a is UInt256 a && uint256b is UInt256 b) + BigInteger resUInt256; + if (test.A >= test.B) { - a.Invoking(y => y - b) - .Should().Throw() - .WithMessage($"Underflow in subtraction {a} - {b}"); + if (uint256a is UInt256 a && uint256b is UInt256 b) + { + res = (T)(object)(a - b); + res.Convert(out resUInt256); + } + else + { + uint256a.Subtract(uint256b, out res); + res.Convert(out resUInt256); + } + resUInt256.Should().Be(resBigInt); } else { - uint256a.Subtract(uint256b, out T res); - res.Convert(out resUInt256); - resUInt256.Should().Be(resBigInt); + if (uint256a is UInt256 a && uint256b is UInt256 b) + { + res = default!; + a.Invoking(y => y - b) + .Should().Throw() + .WithMessage($"Underflow in subtraction {a} - {b}"); + } + else + { + uint256a.Subtract(uint256b, out res); + res.Convert(out resUInt256); + resUInt256.Should().Be(resBigInt); + } } } } @@ -230,6 +304,24 @@ public virtual void MultiplyMod((BigInteger A, BigInteger B, BigInteger M) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.MultiplyMod(uint256b, uint256m, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + uint256a.MultiplyMod(uint256b, uint256m, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256b = convert(test.B); + uint256a.MultiplyMod(uint256b, uint256m, out uint256m); + uint256m.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] @@ -248,6 +340,18 @@ public virtual void Div((BigInteger A, BigInteger B) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.Divide(uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + uint256a.Divide(uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] @@ -262,6 +366,18 @@ public virtual void And((BigInteger A, BigInteger B) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + T.And(uint256a, uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + T.And(uint256a, uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] @@ -276,6 +392,18 @@ public virtual void Or((BigInteger A, BigInteger B) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + T.Or(uint256a, uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + T.Or(uint256a, uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.TestCases))] @@ -290,6 +418,19 @@ public virtual void Xor((BigInteger A, BigInteger B) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + T.Xor(uint256a, uint256b, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + + T.Xor(uint256a, uint256b, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))] @@ -305,6 +446,12 @@ public virtual void Exp((BigInteger A, int n) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.Exp(convertFromInt(test.n), out uint256a); + res.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(TernaryOps), nameof(TernaryOps.TestCases))] @@ -326,6 +473,24 @@ public virtual void ExpMod((BigInteger A, BigInteger B, BigInteger M) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.ExpMod(uint256b, uint256m, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256a = convert(test.A); + uint256a.ExpMod(uint256b, uint256m, out uint256b); + uint256b.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); + + uint256b = convert(test.B); + uint256a.ExpMod(uint256b, uint256m, out uint256m); + uint256m.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))] @@ -344,6 +509,12 @@ public virtual void Lsh((BigInteger A, int n) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.LeftShift(test.n, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(BinaryOps), nameof(BinaryOps.ShiftTestCases))] @@ -362,6 +533,12 @@ public virtual void Rsh((BigInteger A, int n) test) res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + uint256a.RightShift(test.n, out uint256a); + uint256a.Convert(out resUInt256); + + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(UnaryOps), nameof(UnaryOps.TestCases))] @@ -385,6 +562,11 @@ public virtual void Not(BigInteger test) T.Not(uint256, out T res); res.Convert(out BigInteger resUInt256); resUInt256.Should().Be(resBigInt); + + // Test reusing input as output + T.Not(uint256, out uint256); + uint256.Convert(out resUInt256); + resUInt256.Should().Be(resBigInt); } [TestCaseSource(typeof(UnaryOps), nameof(UnaryOps.TestCases))] @@ -476,7 +658,7 @@ string Expected(string valueString) { if (valueString.Contains("Infinity")) { - return valueString.StartsWith('-') ? "-∞" : "∞" ; + return valueString.StartsWith('-') ? "-∞" : "∞"; } string expected = valueString.Replace(",", ""); return type == typeof(float) ? expected[..Math.Min(6, expected.Length)] : type == typeof(double) ? expected[..Math.Min(14, expected.Length)] : expected; diff --git a/src/Nethermind.Int256.Test/UnaryOps.cs b/src/Nethermind.Int256.Test/UnaryOps.cs index 651b11f..1908c13 100644 --- a/src/Nethermind.Int256.Test/UnaryOps.cs +++ b/src/Nethermind.Int256.Test/UnaryOps.cs @@ -62,7 +62,7 @@ public static class UnaryOps public static IEnumerable ShiftTestCases => Enumerable.Range(0, 257); const int Seed = 0; - + public static IEnumerable RandomSigned(int count) { Random rand = new(Seed); diff --git a/src/Nethermind.Int256/Int256.cs b/src/Nethermind.Int256/Int256.cs index 79b3637..08918bd 100644 --- a/src/Nethermind.Int256/Int256.cs +++ b/src/Nethermind.Int256/Int256.cs @@ -67,7 +67,7 @@ public Int256(int n) Add(in a, in b, out Int256 res); return res; } - + public static bool AddOverflow(in Int256 a, in Int256 b, out Int256 res) { var overflow = UInt256.AddOverflow(a._value, b._value, out UInt256 ures); diff --git a/src/Nethermind.Int256/UInt256.cs b/src/Nethermind.Int256/UInt256.cs index 7c2a2c7..75a4c73 100644 --- a/src/Nethermind.Int256/UInt256.cs +++ b/src/Nethermind.Int256/UInt256.cs @@ -514,16 +514,16 @@ public static void AddMod(in UInt256 x, in UInt256 y, in UInt256 m, out UInt256 return; } - if (AddOverflow(x, y, out res)) + if (AddOverflow(x, y, out UInt256 intermediate)) { const int length = 5; - Span sum = stackalloc ulong[length] { res.u0, res.u1, res.u2, res.u3, 1 }; + Span sum = stackalloc ulong[length] { intermediate.u0, intermediate.u1, intermediate.u2, intermediate.u3, 1 }; Span quot = stackalloc ulong[length]; Udivrem(ref MemoryMarshal.GetReference(quot), ref MemoryMarshal.GetReference(sum), length, in m, out res); } else { - Mod(res, m, out res); + Mod(intermediate, m, out res); } } @@ -965,19 +965,21 @@ private static bool SubtractImpl(in UInt256 a, in UInt256 b, out UInt256 res) public static void SubtractMod(in UInt256 a, in UInt256 b, in UInt256 m, out UInt256 res) { - if (SubtractUnderflow(a, b, out res)) + if (SubtractUnderflow(a, b, out UInt256 intermediate)) { - Subtract(b, a, out res); - Mod(res, m, out res); - if (!res.IsZero) + Subtract(b, a, out intermediate); + Mod(intermediate, m, out intermediate); + if (!intermediate.IsZero) { - Subtract(m, res, out res); + Subtract(m, intermediate, out intermediate); } } else { - Mod(res, m, out res); + Mod(intermediate, m, out intermediate); } + + res = intermediate; } public void SubtractMod(in UInt256 a, in UInt256 m, out UInt256 res) => SubtractMod(this, a, m, out res); @@ -1055,7 +1057,7 @@ private void Squared(out UInt256 result) (carry0, ulong temp2) = UmulHopi(carry0, u0, u2); (ulong carry1, ulong res1) = UmulHopi(temp1, u0, u1); - (carry1, temp2) = UmulStepi(temp2,u1 , u1, carry1); + (carry1, temp2) = UmulStepi(temp2, u1, u1, carry1); (ulong carry2, ulong res2) = UmulHopi(temp2, u0, u2); @@ -1090,17 +1092,19 @@ public static void ExpMod(in UInt256 b, in UInt256 e, in UInt256 m, out UInt256 result = Zero; return; } - result = One; + UInt256 intermediate = One; UInt256 bs = b; int len = e.BitLen; for (int i = 0; i < len; i++) { if (e.Bit(i)) { - MultiplyMod(result, bs, m, out result); + MultiplyMod(intermediate, bs, m, out intermediate); } MultiplyMod(bs, bs, m, out bs); } + + result = intermediate; } public void ExpMod(in UInt256 exp, in UInt256 m, out UInt256 res) => ExpMod(this, exp, m, out res); @@ -1247,9 +1251,10 @@ public static void Divide(in UInt256 x, in UInt256 y, out UInt256 res) // At this point, we know // x/y ; x > y > 0 - res = default; // initialize with zeros + UInt256 intermediate = default; // initialize with zeros const int length = 4; - Udivrem(ref Unsafe.As(ref res), ref Unsafe.As(ref Unsafe.AsRef(in x)), length, y, out UInt256 _); + Udivrem(ref Unsafe.As(ref intermediate), ref Unsafe.As(ref Unsafe.AsRef(in x)), length, y, out UInt256 _); + res = intermediate; } public void Divide(in UInt256 a, out UInt256 res) => Divide(this, a, out res); @@ -1332,8 +1337,7 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res) } } - res = Zero; - ulong z0 = res.u0, z1 = res.u1, z2 = res.u2, z3 = res.u3; + ulong z0 = 0, z1 = 0, z2 = 0; ulong a = 0, b = 0; // Big swaps first if (n > 192) @@ -1376,7 +1380,7 @@ public static void Lsh(in UInt256 x, int n, out UInt256 res) sh128: a = Rsh(res.u2, 64 - n); z2 = Lsh(res.u2, n) | b; - + ulong z3; sh192: z3 = Lsh(res.u3, n) | a; @@ -1426,9 +1430,12 @@ public static void Rsh(in UInt256 x, int n, out UInt256 res) } } - res = Zero; - ulong z0 = res.u0, z1 = res.u1, z2 = res.u2, z3 = res.u3; ulong a = 0, b = 0; + ulong z3; + ulong z2; + ulong z1; + + ulong z0; // Big swaps first if (n > 192) {