Skip to content

Commit c3e5fb6

Browse files
authored
feat(tests/zkevm): add worst-case benchmark for MOD (#1151)
Add a worst-case test running a block with as many MOD instructions with arguments of the parametrized range.
1 parent f14624b commit c3e5fb6

File tree

1 file changed

+140
-8
lines changed

1 file changed

+140
-8
lines changed

tests/zkevm/test_worst_compute.py

Lines changed: 140 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,21 @@
3737
KECCAK_RATE = 136
3838

3939

40+
def neg(x: int) -> int:
41+
"""Negate the given integer in the two's complement 256-bit range."""
42+
assert 0 <= x < 2**256
43+
return 2**256 - x
44+
45+
46+
def make_dup(index: int) -> Opcode:
47+
"""
48+
Create a DUP instruction which duplicates the index-th (counting from 0) element
49+
from the top of the stack. E.g. make_dup(0) → DUP1.
50+
"""
51+
assert 0 <= index < 16
52+
return Opcode(0x80 + index, pushed_stack_items=1, min_stack_height=index + 1)
53+
54+
4055
@pytest.mark.valid_from("Cancun")
4156
def test_worst_keccak(
4257
blockchain_test: BlockchainTestFiller,
@@ -758,7 +773,7 @@ def sar(x, s):
758773
rng = random.Random(1) # Use random with a fixed seed.
759774
initial_value = 2**256 - 1 # The initial value to be shifted; should be negative for SAR.
760775

761-
# Create the list of shift amounts if length 15 (max reachable by DUPs instructions).
776+
# Create the list of shift amounts with 15 elements (max reachable by DUPs instructions).
762777
# For the worst case keep the values small and omit values divisible by 8.
763778
shift_amounts = [x + (x >= 8) + (x >= 15) for x in range(1, 16)]
764779

@@ -775,18 +790,13 @@ def select_shift_amount(shift_fn, v):
775790
if new_v != 0:
776791
return new_v, index
777792

778-
def make_dup(i):
779-
"""Create a DUP instruction to get the i-th shift amount constant from the stack."""
780-
# TODO: Create a global helper for this.
781-
return Opcode(0x80 + (len(shift_amounts) - i))
782-
783793
code_body = Bytecode()
784794
v = initial_value
785795
while len(code_body) <= code_body_len - 4:
786796
v, i = select_shift_amount(shl, v)
787-
code_body += make_dup(i) + Op.SHL
797+
code_body += make_dup(len(shift_amounts) - i) + Op.SHL
788798
v, i = select_shift_amount(shift_right_fn, v)
789-
code_body += make_dup(i) + shift_right
799+
code_body += make_dup(len(shift_amounts) - i) + shift_right
790800

791801
code = code_prefix + code_body + code_suffix
792802
assert len(code) == MAX_CODE_SIZE - 2
@@ -806,3 +816,125 @@ def make_dup(i):
806816
post={},
807817
blocks=[Block(txs=[tx])],
808818
)
819+
820+
821+
@pytest.mark.valid_from("Cancun")
822+
@pytest.mark.parametrize("mod_bits", [255, 191, 127, 63])
823+
@pytest.mark.parametrize("op", [Op.MOD, Op.SMOD])
824+
def test_worst_mod(
825+
blockchain_test: BlockchainTestFiller,
826+
pre: Alloc,
827+
mod_bits: int,
828+
op: Op,
829+
):
830+
"""
831+
Test running a block with as many MOD instructions with arguments of the parametrized range.
832+
The test program consists of code segments evaluating the "MOD chain":
833+
mod[0] = calldataload(0)
834+
mod[1] = numerators[indexes[0]] % mod[0]
835+
mod[2] = numerators[indexes[1]] % mod[1] ...
836+
The "numerators" is a pool of 15 constants pushed to the EVM stack at the program start.
837+
The order of accessing the numerators is selected in a way the mod value remains in the range
838+
as long as possible.
839+
"""
840+
# For SMOD we negate both numerator and modulus. The underlying computation is the same,
841+
# just the SMOD implementation will have to additionally handle the sign bits.
842+
# The result stays negative.
843+
should_negate = op == Op.SMOD
844+
845+
num_numerators = 15
846+
numerator_bits = 256 if not should_negate else 255
847+
numerator_max = 2**numerator_bits - 1
848+
numerator_min = 2 ** (numerator_bits - 1)
849+
850+
# Pick the modulus min value so that it is _unlikely_ to drop to the lower word count.
851+
assert mod_bits >= 63
852+
mod_min = 2 ** (mod_bits - 63)
853+
854+
# Select the random seed giving the longest found MOD chain.
855+
# You can look for a longer one by increasing the numerators_min_len. This will activate
856+
# the while loop below.
857+
match op, mod_bits:
858+
case Op.MOD, 255:
859+
seed = 20393
860+
numerators_min_len = 750
861+
case Op.MOD, 191:
862+
seed = 25979
863+
numerators_min_len = 770
864+
case Op.MOD, 127:
865+
seed = 17671
866+
numerators_min_len = 750
867+
case Op.MOD, 63:
868+
seed = 29181
869+
numerators_min_len = 730
870+
case Op.SMOD, 255:
871+
seed = 4015
872+
numerators_min_len = 750
873+
case Op.SMOD, 191:
874+
seed = 17355
875+
numerators_min_len = 750
876+
case Op.SMOD, 127:
877+
seed = 897
878+
numerators_min_len = 750
879+
case Op.SMOD, 63:
880+
seed = 7562
881+
numerators_min_len = 720
882+
case _:
883+
raise ValueError(f"{mod_bits}-bit {op} not supported.")
884+
885+
while True:
886+
rng = random.Random(seed)
887+
888+
# Create the list of random numerators.
889+
numerators = [rng.randint(numerator_min, numerator_max) for _ in range(num_numerators)]
890+
891+
# Create the random initial modulus.
892+
initial_mod = rng.randint(2 ** (mod_bits - 1), 2**mod_bits - 1)
893+
894+
# Evaluate the MOD chain and collect the order of accessing numerators.
895+
mod = initial_mod
896+
indexes = []
897+
while mod >= mod_min:
898+
results = [n % mod for n in numerators] # Compute results for each numerator.
899+
i = max(range(len(results)), key=results.__getitem__) # And pick the best one.
900+
mod = results[i]
901+
indexes.append(i)
902+
903+
assert len(indexes) > numerators_min_len # Disable if you want to find longer MOD chains.
904+
if len(indexes) > numerators_min_len:
905+
break
906+
seed += 1
907+
print(f"{seed=}")
908+
909+
# TODO: Don't use fixed PUSH32. Let Bytecode helpers to select optimal push opcode.
910+
code_constant_pool = sum((Op.PUSH32[n] for n in numerators), Bytecode())
911+
code_prefix = code_constant_pool + Op.JUMPDEST
912+
code_suffix = Op.JUMP(len(code_constant_pool))
913+
code_body_len = MAX_CODE_SIZE - len(code_prefix) - len(code_suffix)
914+
code_segment = (
915+
Op.CALLDATALOAD(0) + sum(make_dup(len(numerators) - i) + op for i in indexes) + Op.POP
916+
)
917+
code = (
918+
code_prefix
919+
# TODO: Add int * Bytecode support
920+
+ sum(code_segment for _ in range(code_body_len // len(code_segment)))
921+
+ code_suffix
922+
)
923+
assert (MAX_CODE_SIZE - len(code_segment)) < len(code) <= MAX_CODE_SIZE
924+
925+
env = Environment()
926+
927+
input_value = initial_mod if not should_negate else neg(initial_mod)
928+
tx = Transaction(
929+
to=pre.deploy_contract(code=code),
930+
data=input_value.to_bytes(32, byteorder="big"),
931+
gas_limit=env.gas_limit,
932+
sender=pre.fund_eoa(),
933+
)
934+
935+
blockchain_test(
936+
env=env,
937+
pre=pre,
938+
post={},
939+
blocks=[Block(txs=[tx])],
940+
)

0 commit comments

Comments
 (0)