Skip to content

zkevm: amortized bn128_pairings #1656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
May 26, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
121 changes: 116 additions & 5 deletions tests/zkevm/test_worst_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@
from typing import cast

import pytest
from py_ecc.bn128 import G1, G2, multiply

from ethereum_test_base_types.base_types import Bytes
from ethereum_test_forks import Fork
from ethereum_test_tools import (
Address,
Expand All @@ -19,6 +21,7 @@
BlockchainTestFiller,
Bytecode,
Environment,
StateTestFiller,
Transaction,
)
from ethereum_test_tools.code.generators import While
Expand Down Expand Up @@ -295,10 +298,6 @@ def test_worst_modexp(
pytest.param(
0x08,
[
# TODO: the following are only two inputs, but this can be extended
# to more inputs to amortize costs as much as possible. Additionally,
# there might be worse pairings that can be used.
#
# First pairing
"1C76476F4DEF4BB94541D57EBBA1193381FFA7AA76ADA664DD31C16024C43F59",
"3034DD2920F673E204FEE2811C678745FC819B55D3E9D294E45C9B03A76AEF41",
Expand All @@ -314,7 +313,20 @@ def test_worst_modexp(
"090689D0585FF075EC9E99AD690C3395BC4B313370B38EF355ACDADCD122975B",
"12C85EA5DB8C6DEB4AAB71808DCB408FE3D1E7690C43D37B4CE6CC0166FA7DAA",
],
id="bn128_pairing",
id="bn128_two_pairings",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Leaving this test as two_pairings to cover the point multiplication of the pairing check.

),
pytest.param(
0x08,
[
# First pairing
"1C76476F4DEF4BB94541D57EBBA1193381FFA7AA76ADA664DD31C16024C43F59",
"3034DD2920F673E204FEE2811C678745FC819B55D3E9D294E45C9B03A76AEF41",
"209DD15EBFF5D46C4BD888E51A93CF99A7329636C63514396B4A452003A35BF7",
"04BF11CA01483BFA8B34B43561848D28905960114C8AC04049AF4B6315A41678",
"2BB8324AF6CFC93537A2AD1A445CFD0CA2A71ACD7AC41FADBF933C2A51BE344D",
"120A2A4CF30C1BF9845F20C6FE39E07EA2CCE61F0C9BB048165FE5E4DE877550",
],
id="bn128_one_pairing",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creating the case of one pairing, which maximizes the number of total final exponentiations in the block.

),
pytest.param(
Blake2bSpec.BLAKE2_PRECOMPILE_ADDRESS,
Expand Down Expand Up @@ -955,3 +967,102 @@ def test_empty_block(
post={},
blocks=[Block(txs=[])],
)


@pytest.mark.valid_from("Cancun")
@pytest.mark.slow()
def test_amortized_bn128_pairings(
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the new test that optimizes for number of miller loops. That is, amortize the base cost (final exponentiation) in as many miller loops as possible while keeping in check quadratic memory expansion costs.

state_test: StateTestFiller,
pre: Alloc,
fork: Fork,
):
"""Test running a block with as many BN128 pairings as possible."""
env = Environment()

base_cost = 45_000
pairing_cost = 34_000
Comment on lines +982 to +983
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LMK if these costs can be pulled from somewhere. If not, we can leave it this way.

size_per_pairing = 192

gsc = fork.gas_costs()
intrinsic_gas_calculator = fork.transaction_intrinsic_cost_calculator()
mem_exp_gas_calculator = fork.memory_expansion_gas_calculator()

# This is a theoretical maximum number of pairings that can be done in a block.
# It is only used for an upper bound for calculating the optimal number of pairings below.
maximum_number_of_pairings = (env.gas_limit - base_cost) // pairing_cost

# Discover the optimal number of pairings balancing two dimensions:
# 1. Amortize the precompile base cost as much as possible.
# 2. The cost of the memory expansion.
max_pairings = 0
optimal_per_call_num_pairings = 0
for i in range(1, maximum_number_of_pairings + 1):
# We'll pass all pairing arguments via calldata.
available_gas_after_intrinsic = env.gas_limit - intrinsic_gas_calculator(
calldata=[0xFF] * size_per_pairing * i # 0xFF is to indicate non-zero bytes.
)
available_gas_after_expansion = max(
0,
available_gas_after_intrinsic - mem_exp_gas_calculator(new_bytes=i * size_per_pairing),
)

# This is ignoring "glue" opcodes, but helps to have a rough idea of the right
# cutting point.
approx_gas_cost_per_call = gsc.G_WARM_ACCOUNT_ACCESS + base_cost + i * pairing_cost

num_precompile_calls = available_gas_after_expansion // approx_gas_cost_per_call
num_pairings_done = num_precompile_calls * i # Each precompile call does i pairings.

if num_pairings_done > max_pairings:
max_pairings = num_pairings_done
optimal_per_call_num_pairings = i

calldata = Op.CALLDATACOPY(size=Op.CALLDATASIZE)
attack_block = Op.POP(Op.STATICCALL(Op.GAS, 0x08, 0, Op.CALLDATASIZE, 0, 0))
code = code_loop_precompile_call(calldata, attack_block)

code_address = pre.deploy_contract(code=code)

tx = Transaction(
to=code_address,
gas_limit=env.gas_limit,
data=_generate_bn128_pairs(optimal_per_call_num_pairings, 42),
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given the discovered optimal number of pairings per precompile call, generate the corresponding pairs with a fixed seed.

sender=pre.fund_eoa(),
)

state_test(
env=env,
pre=pre,
post={},
tx=tx,
)


def _generate_bn128_pairs(n: int, seed: int = 0):
rng = random.Random(seed)
calldata = Bytes()

for _ in range(n):
priv_key_g1 = rng.randint(1, 2**32 - 1)
priv_key_g2 = rng.randint(1, 2**32 - 1)

point_x_affine = multiply(G1, priv_key_g1)
point_y_affine = multiply(G2, priv_key_g2)

assert point_x_affine is not None, "G1 multiplication resulted in point at infinity"
assert point_y_affine is not None, "G2 multiplication resulted in point at infinity"

g1_x_bytes = point_x_affine[0].n.to_bytes(32, "big")
g1_y_bytes = point_x_affine[1].n.to_bytes(32, "big")
g1_serialized = g1_x_bytes + g1_y_bytes

g2_x_c1_bytes = point_y_affine[0].coeffs[1].n.to_bytes(32, "big") # type: ignore
g2_x_c0_bytes = point_y_affine[0].coeffs[0].n.to_bytes(32, "big") # type: ignore
g2_y_c1_bytes = point_y_affine[1].coeffs[1].n.to_bytes(32, "big") # type: ignore
g2_y_c0_bytes = point_y_affine[1].coeffs[0].n.to_bytes(32, "big") # type: ignore
g2_serialized = g2_x_c1_bytes + g2_x_c0_bytes + g2_y_c1_bytes + g2_y_c0_bytes

pair_calldata = g1_serialized + g2_serialized
calldata = Bytes(calldata + pair_calldata)

return calldata
Loading