|
6 | 6 | """
|
7 | 7 |
|
8 | 8 | import math
|
| 9 | +import operator |
9 | 10 | import random
|
10 | 11 | from enum import Enum, auto
|
11 | 12 | from typing import cast
|
@@ -1315,6 +1316,109 @@ def test_worst_memory_access(
|
1315 | 1316 | )
|
1316 | 1317 |
|
1317 | 1318 |
|
| 1319 | +@pytest.mark.valid_from("Cancun") |
| 1320 | +@pytest.mark.parametrize("mod_bits", [255, 191, 127, 63]) |
| 1321 | +@pytest.mark.parametrize("op", [Op.ADDMOD, Op.MULMOD]) |
| 1322 | +def test_worst_modarith( |
| 1323 | + state_test: StateTestFiller, |
| 1324 | + pre: Alloc, |
| 1325 | + mod_bits: int, |
| 1326 | + op: Op, |
| 1327 | +): |
| 1328 | + """ |
| 1329 | + Test running a block with as many "op" instructions with arguments of the parametrized range. |
| 1330 | + The test program consists of code segments evaluating the "op chain": |
| 1331 | + mod[0] = calldataload(0) |
| 1332 | + mod[1] = (fixed_arg op args[indexes[0]]) % mod[0] |
| 1333 | + mod[2] = (fixed_arg op args[indexes[1]]) % mod[1] |
| 1334 | + The "args" is a pool of 15 constants pushed to the EVM stack at the program start. |
| 1335 | + The "fixed_arg" is the 0xFF...FF constant added to the EVM stack by PUSH32 |
| 1336 | + just before executing the "op". |
| 1337 | + The order of accessing the numerators is selected in a way the mod value remains in the range |
| 1338 | + as long as possible. |
| 1339 | + """ |
| 1340 | + fixed_arg = 2**256 - 1 |
| 1341 | + num_args = 15 |
| 1342 | + |
| 1343 | + # Pick the modulus min value so that it is _unlikely_ to drop to the lower word count. |
| 1344 | + assert mod_bits >= 63 |
| 1345 | + mod_min = 2 ** (mod_bits - 63) |
| 1346 | + |
| 1347 | + # Select the random seed giving the longest found op chain. |
| 1348 | + # You can look for a longer one by increasing the op_chain_len. This will activate |
| 1349 | + # the while loop below. |
| 1350 | + op_chain_len = 666 |
| 1351 | + match op, mod_bits: |
| 1352 | + case Op.ADDMOD, 255: |
| 1353 | + seed = 4 |
| 1354 | + case Op.ADDMOD, 191: |
| 1355 | + seed = 2 |
| 1356 | + case Op.ADDMOD, 127: |
| 1357 | + seed = 2 |
| 1358 | + case Op.ADDMOD, 63: |
| 1359 | + seed = 64 |
| 1360 | + case Op.MULMOD, 255: |
| 1361 | + seed = 5 |
| 1362 | + case Op.MULMOD, 191: |
| 1363 | + seed = 389 |
| 1364 | + case Op.MULMOD, 127: |
| 1365 | + seed = 5 |
| 1366 | + case Op.MULMOD, 63: |
| 1367 | + # For this setup we were not able to find an op-chain longer than 600. |
| 1368 | + seed = 4193 |
| 1369 | + op_chain_len = 600 |
| 1370 | + case _: |
| 1371 | + raise ValueError(f"{mod_bits}-bit {op} not supported.") |
| 1372 | + |
| 1373 | + while True: |
| 1374 | + rng = random.Random(seed) |
| 1375 | + args = [rng.randint(2**255, 2**256 - 1) for _ in range(num_args)] |
| 1376 | + initial_mod = rng.randint(2 ** (mod_bits - 1), 2**mod_bits - 1) |
| 1377 | + |
| 1378 | + # Evaluate the op chain and collect the order of accessing numerators. |
| 1379 | + op_fn = operator.add if op == Op.ADDMOD else operator.mul |
| 1380 | + mod = initial_mod |
| 1381 | + indexes: list[int] = [] |
| 1382 | + while mod >= mod_min and len(indexes) < op_chain_len: |
| 1383 | + results = [op_fn(a, fixed_arg) % mod for a in args] |
| 1384 | + i = max(range(len(results)), key=results.__getitem__) # And pick the best one. |
| 1385 | + mod = results[i] |
| 1386 | + indexes.append(i) |
| 1387 | + |
| 1388 | + assert len(indexes) == op_chain_len # Disable if you want to find longer op chains. |
| 1389 | + if len(indexes) == op_chain_len: |
| 1390 | + break |
| 1391 | + seed += 1 |
| 1392 | + print(f"{seed=}") |
| 1393 | + |
| 1394 | + code_constant_pool = sum((Op.PUSH32[n] for n in args), Bytecode()) |
| 1395 | + code_segment = ( |
| 1396 | + Op.CALLDATALOAD(0) |
| 1397 | + + sum(make_dup(len(args) - i) + Op.PUSH32[fixed_arg] + op for i in indexes) |
| 1398 | + + Op.POP |
| 1399 | + ) |
| 1400 | + # Construct the final code. Because of the usage of PUSH32 the code segment is very long, |
| 1401 | + # so don't try to include multiple of these. |
| 1402 | + code = code_constant_pool + Op.JUMPDEST + code_segment + Op.JUMP(len(code_constant_pool)) |
| 1403 | + assert (MAX_CODE_SIZE - len(code_segment)) < len(code) <= MAX_CODE_SIZE |
| 1404 | + |
| 1405 | + env = Environment() |
| 1406 | + |
| 1407 | + tx = Transaction( |
| 1408 | + to=pre.deploy_contract(code=code), |
| 1409 | + data=initial_mod.to_bytes(32, byteorder="big"), |
| 1410 | + gas_limit=env.gas_limit, |
| 1411 | + sender=pre.fund_eoa(), |
| 1412 | + ) |
| 1413 | + |
| 1414 | + state_test( |
| 1415 | + env=env, |
| 1416 | + pre=pre, |
| 1417 | + post={}, |
| 1418 | + tx=tx, |
| 1419 | + ) |
| 1420 | + |
| 1421 | + |
1318 | 1422 | @pytest.mark.valid_from("Cancun")
|
1319 | 1423 | def test_empty_block(
|
1320 | 1424 | blockchain_test: BlockchainTestFiller,
|
|
0 commit comments