Skip to content

Commit 12bcc4b

Browse files
jsignmarioevz
andauthored
feat(tests/zkevm): amortized bn128_pairings (#1656)
* zkevm: add optimized bn128 pairings Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * keep original test too Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * add slow marker Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * add extra cases Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * add assertions Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * fix lints Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * use state_test Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> * Update tests/zkevm/test_worst_compute.py --------- Signed-off-by: Ignacio Hagopian <jsign.uy@gmail.com> Co-authored-by: Mario Vega <marioevz@gmail.com>
1 parent 3810d22 commit 12bcc4b

File tree

1 file changed

+116
-5
lines changed

1 file changed

+116
-5
lines changed

tests/zkevm/test_worst_compute.py

Lines changed: 116 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,9 @@
1010
from typing import cast
1111

1212
import pytest
13+
from py_ecc.bn128 import G1, G2, multiply
1314

15+
from ethereum_test_base_types.base_types import Bytes
1416
from ethereum_test_forks import Fork
1517
from ethereum_test_tools import (
1618
Address,
@@ -19,6 +21,7 @@
1921
BlockchainTestFiller,
2022
Bytecode,
2123
Environment,
24+
StateTestFiller,
2225
Transaction,
2326
)
2427
from ethereum_test_tools.code.generators import While
@@ -295,10 +298,6 @@ def test_worst_modexp(
295298
pytest.param(
296299
0x08,
297300
[
298-
# TODO: the following are only two inputs, but this can be extended
299-
# to more inputs to amortize costs as much as possible. Additionally,
300-
# there might be worse pairings that can be used.
301-
#
302301
# First pairing
303302
"1C76476F4DEF4BB94541D57EBBA1193381FFA7AA76ADA664DD31C16024C43F59",
304303
"3034DD2920F673E204FEE2811C678745FC819B55D3E9D294E45C9B03A76AEF41",
@@ -314,7 +313,20 @@ def test_worst_modexp(
314313
"090689D0585FF075EC9E99AD690C3395BC4B313370B38EF355ACDADCD122975B",
315314
"12C85EA5DB8C6DEB4AAB71808DCB408FE3D1E7690C43D37B4CE6CC0166FA7DAA",
316315
],
317-
id="bn128_pairing",
316+
id="bn128_two_pairings",
317+
),
318+
pytest.param(
319+
0x08,
320+
[
321+
# First pairing
322+
"1C76476F4DEF4BB94541D57EBBA1193381FFA7AA76ADA664DD31C16024C43F59",
323+
"3034DD2920F673E204FEE2811C678745FC819B55D3E9D294E45C9B03A76AEF41",
324+
"209DD15EBFF5D46C4BD888E51A93CF99A7329636C63514396B4A452003A35BF7",
325+
"04BF11CA01483BFA8B34B43561848D28905960114C8AC04049AF4B6315A41678",
326+
"2BB8324AF6CFC93537A2AD1A445CFD0CA2A71ACD7AC41FADBF933C2A51BE344D",
327+
"120A2A4CF30C1BF9845F20C6FE39E07EA2CCE61F0C9BB048165FE5E4DE877550",
328+
],
329+
id="bn128_one_pairing",
318330
),
319331
pytest.param(
320332
Blake2bSpec.BLAKE2_PRECOMPILE_ADDRESS,
@@ -955,3 +967,102 @@ def test_empty_block(
955967
post={},
956968
blocks=[Block(txs=[])],
957969
)
970+
971+
972+
@pytest.mark.valid_from("Cancun")
973+
@pytest.mark.slow()
974+
def test_amortized_bn128_pairings(
975+
state_test: StateTestFiller,
976+
pre: Alloc,
977+
fork: Fork,
978+
):
979+
"""Test running a block with as many BN128 pairings as possible."""
980+
env = Environment()
981+
982+
base_cost = 45_000
983+
pairing_cost = 34_000
984+
size_per_pairing = 192
985+
986+
gsc = fork.gas_costs()
987+
intrinsic_gas_calculator = fork.transaction_intrinsic_cost_calculator()
988+
mem_exp_gas_calculator = fork.memory_expansion_gas_calculator()
989+
990+
# This is a theoretical maximum number of pairings that can be done in a block.
991+
# It is only used for an upper bound for calculating the optimal number of pairings below.
992+
maximum_number_of_pairings = (env.gas_limit - base_cost) // pairing_cost
993+
994+
# Discover the optimal number of pairings balancing two dimensions:
995+
# 1. Amortize the precompile base cost as much as possible.
996+
# 2. The cost of the memory expansion.
997+
max_pairings = 0
998+
optimal_per_call_num_pairings = 0
999+
for i in range(1, maximum_number_of_pairings + 1):
1000+
# We'll pass all pairing arguments via calldata.
1001+
available_gas_after_intrinsic = env.gas_limit - intrinsic_gas_calculator(
1002+
calldata=[0xFF] * size_per_pairing * i # 0xFF is to indicate non-zero bytes.
1003+
)
1004+
available_gas_after_expansion = max(
1005+
0,
1006+
available_gas_after_intrinsic - mem_exp_gas_calculator(new_bytes=i * size_per_pairing),
1007+
)
1008+
1009+
# This is ignoring "glue" opcodes, but helps to have a rough idea of the right
1010+
# cutting point.
1011+
approx_gas_cost_per_call = gsc.G_WARM_ACCOUNT_ACCESS + base_cost + i * pairing_cost
1012+
1013+
num_precompile_calls = available_gas_after_expansion // approx_gas_cost_per_call
1014+
num_pairings_done = num_precompile_calls * i # Each precompile call does i pairings.
1015+
1016+
if num_pairings_done > max_pairings:
1017+
max_pairings = num_pairings_done
1018+
optimal_per_call_num_pairings = i
1019+
1020+
calldata = Op.CALLDATACOPY(size=Op.CALLDATASIZE)
1021+
attack_block = Op.POP(Op.STATICCALL(Op.GAS, 0x08, 0, Op.CALLDATASIZE, 0, 0))
1022+
code = code_loop_precompile_call(calldata, attack_block)
1023+
1024+
code_address = pre.deploy_contract(code=code)
1025+
1026+
tx = Transaction(
1027+
to=code_address,
1028+
gas_limit=env.gas_limit,
1029+
data=_generate_bn128_pairs(optimal_per_call_num_pairings, 42),
1030+
sender=pre.fund_eoa(),
1031+
)
1032+
1033+
state_test(
1034+
env=env,
1035+
pre=pre,
1036+
post={},
1037+
tx=tx,
1038+
)
1039+
1040+
1041+
def _generate_bn128_pairs(n: int, seed: int = 0):
1042+
rng = random.Random(seed)
1043+
calldata = Bytes()
1044+
1045+
for _ in range(n):
1046+
priv_key_g1 = rng.randint(1, 2**32 - 1)
1047+
priv_key_g2 = rng.randint(1, 2**32 - 1)
1048+
1049+
point_x_affine = multiply(G1, priv_key_g1)
1050+
point_y_affine = multiply(G2, priv_key_g2)
1051+
1052+
assert point_x_affine is not None, "G1 multiplication resulted in point at infinity"
1053+
assert point_y_affine is not None, "G2 multiplication resulted in point at infinity"
1054+
1055+
g1_x_bytes = point_x_affine[0].n.to_bytes(32, "big")
1056+
g1_y_bytes = point_x_affine[1].n.to_bytes(32, "big")
1057+
g1_serialized = g1_x_bytes + g1_y_bytes
1058+
1059+
g2_x_c1_bytes = point_y_affine[0].coeffs[1].n.to_bytes(32, "big") # type: ignore
1060+
g2_x_c0_bytes = point_y_affine[0].coeffs[0].n.to_bytes(32, "big") # type: ignore
1061+
g2_y_c1_bytes = point_y_affine[1].coeffs[1].n.to_bytes(32, "big") # type: ignore
1062+
g2_y_c0_bytes = point_y_affine[1].coeffs[0].n.to_bytes(32, "big") # type: ignore
1063+
g2_serialized = g2_x_c1_bytes + g2_x_c0_bytes + g2_y_c1_bytes + g2_y_c0_bytes
1064+
1065+
pair_calldata = g1_serialized + g2_serialized
1066+
calldata = Bytes(calldata + pair_calldata)
1067+
1068+
return calldata

0 commit comments

Comments
 (0)