Skip to content
Open
Show file tree
Hide file tree
Changes from 26 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 3 additions & 49 deletions circuit-prover/src/batch_stark_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,7 @@ use p3_air::{Air, AirBuilder, BaseAir};
use p3_batch_stark::{BatchProof, StarkGenericConfig as MSGC, StarkInstance, Val as MVal};
use p3_circuit::tables::Traces;
use p3_field::{BasedVectorSpace, Field};
use p3_matrix::Matrix;
use p3_matrix::dense::RowMajorMatrix;
use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir};
use thiserror::Error;
use tracing::instrument;

Expand All @@ -28,11 +26,10 @@ pub enum Table {
Public = 2,
Add = 3,
Mul = 4,
Mmcs = 5,
}

/// Number of circuit tables included in the unified batch STARK proof.
pub const NUM_TABLES: usize = Table::Mmcs as usize + 1;
pub const NUM_TABLES: usize = Table::Mul as usize + 1;

/// Row counts wrapper with type-safe indexing.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
Expand Down Expand Up @@ -85,8 +82,6 @@ where
pub ext_degree: usize,
/// The binomial coefficient `W` for extension field multiplication, if `ext_degree > 1`.
pub w_binomial: Option<MVal<SC>>,
/// The configuration for the MMCS table.
pub mmcs_config: MmcsTableConfig,
}

impl<SC> core::fmt::Debug for BatchStarkProof<SC>
Expand All @@ -99,7 +94,6 @@ where
.field("rows", &self.rows)
.field("ext_degree", &self.ext_degree)
.field("w_binomial", &self.w_binomial)
.field("mmcs_config", &self.mmcs_config)
.finish()
}
}
Expand All @@ -111,7 +105,6 @@ where
{
config: SC,
table_packing: TablePacking,
mmcs_config: MmcsTableConfig,
}

/// Errors for the batch STARK table prover.
Expand All @@ -137,7 +130,6 @@ enum CircuitTableAir<F: Field, const D: usize> {
Public(PublicAir<F, D>),
Add(AddAir<F, D>),
Mul(MulAir<F, D>),
Mmcs(MmcsVerifyAir<F>),
}

impl<F: Field, const D: usize> BaseAir<F> for CircuitTableAir<F, D> {
Expand All @@ -148,7 +140,6 @@ impl<F: Field, const D: usize> BaseAir<F> for CircuitTableAir<F, D> {
Self::Public(a) => a.width(),
Self::Add(a) => a.width(),
Self::Mul(a) => a.width(),
Self::Mmcs(a) => a.width(),
}
}
}
Expand All @@ -165,7 +156,6 @@ where
Self::Public(a) => a.eval(builder),
Self::Add(a) => a.eval(builder),
Self::Mul(a) => a.eval(builder),
Self::Mmcs(a) => a.eval(builder),
}
}
}
Expand All @@ -179,7 +169,6 @@ where
Self {
config,
table_packing: TablePacking::default(),
mmcs_config: MmcsTableConfig::default(),
}
}

Expand All @@ -189,12 +178,6 @@ where
self
}

#[must_use]
pub fn with_mmcs_config(mut self, mmcs_config: MmcsTableConfig) -> Self {
self.mmcs_config = mmcs_config;
self
}

#[inline]
pub const fn table_packing(&self) -> TablePacking {
self.table_packing
Expand Down Expand Up @@ -289,19 +272,12 @@ where
let mul_matrix: RowMajorMatrix<MVal<SC>> =
MulAir::<MVal<SC>, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);

// Mmcs
let mmcs_air = MmcsVerifyAir::<MVal<SC>>::new(self.mmcs_config);
let mmcs_matrix: RowMajorMatrix<MVal<SC>> =
MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
let mmcs_rows: usize = mmcs_matrix.height();

// Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
let air_witness = CircuitTableAir::Witness(witness_air);
let air_const = CircuitTableAir::Const(const_air);
let air_public = CircuitTableAir::Public(public_air);
let air_add = CircuitTableAir::Add(add_air);
let air_mul = CircuitTableAir::Mul(mul_air);
let air_mmcs = CircuitTableAir::Mmcs(mmcs_air);

// Pre-size for performance
let mut instances = Vec::with_capacity(NUM_TABLES);
Expand Down Expand Up @@ -331,29 +307,16 @@ where
trace: mul_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_mmcs,
trace: mmcs_matrix,
public_values: vec![],
},
]);

let proof = p3_batch_stark::prove_batch(&self.config, instances);

Ok(BatchStarkProof {
proof,
table_packing: packing,
rows: RowCounts::new([
witness_rows,
const_rows,
public_rows,
add_rows,
mul_rows,
mmcs_rows,
]),
rows: RowCounts::new([witness_rows, const_rows, public_rows, add_rows, mul_rows]),
ext_degree: D,
w_binomial: if D > 1 { w_binomial } else { None },
mmcs_config: self.mmcs_config,
})
}

Expand Down Expand Up @@ -396,16 +359,7 @@ where
w,
))
};
let mmcs_air = CircuitTableAir::Mmcs(MmcsVerifyAir::<MVal<SC>>::new(proof.mmcs_config));

let airs = vec![
witness_air,
const_air,
public_air,
add_air,
mul_air,
mmcs_air,
];
let airs = vec![witness_air, const_air, public_air, add_air, mul_air];
// TODO: Handle public values.
let pvs: Vec<Vec<MVal<SC>>> = vec![Vec::new(); NUM_TABLES];

Expand Down
2 changes: 2 additions & 0 deletions recursion/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ categories.workspace = true
[dependencies]
# Plonky3 dependencies
p3-air.workspace = true
p3-batch-stark.workspace = true
p3-challenger.workspace = true
p3-circuit-prover.workspace = true
p3-commit.workspace = true
p3-field.workspace = true
p3-fri.workspace = true
Expand Down
163 changes: 162 additions & 1 deletion recursion/src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@ use alloc::vec::Vec;

use itertools::zip_eq;
use p3_air::Air;
use p3_batch_stark::BatchProof;
use p3_batch_stark::config::{observe_base_as_ext, observe_instance_binding};
use p3_challenger::{CanObserve, CanSample, CanSampleBits, FieldChallenger, GrindingChallenger};
use p3_commit::{BatchOpening, Mmcs, Pcs, PolynomialSpace};
use p3_field::{Field, PrimeCharacteristicRing, PrimeField, TwoAdicField};
use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing, PrimeField, TwoAdicField};
use p3_fri::{FriProof, TwoAdicFriPcs};
use p3_uni_stark::{
Domain, Proof, StarkGenericConfig, SymbolicAirBuilder, Val, VerifierConstraintFolder,
Expand All @@ -28,6 +30,9 @@ pub enum GenerationError {

#[error("Witness check failed during challenge generation.")]
InvalidPowWitness,

#[error("Invalid proof shape: {0}")]
InvalidProofShape(&'static str),
}

/// A type alias for a single opening point and its values.
Expand Down Expand Up @@ -182,6 +187,162 @@ where
Ok(challenges)
}

/// Generates the challenges used in the verification of a batch-STARK proof.
pub fn generate_batch_challenges<SC: StarkGenericConfig, A>(
airs: &[A],
config: &SC,
proof: &BatchProof<SC>,
public_values: &[Vec<Val<SC>>],
extra_params: Option<&[usize]>,
) -> Result<Vec<SC::Challenge>, GenerationError>
where
A: Air<SymbolicAirBuilder<Val<SC>>> + for<'a> Air<VerifierConstraintFolder<'a, SC>>,
SC::Pcs: PcsGeneration<SC, <SC::Pcs as Pcs<SC::Challenge, SC::Challenger>>::Proof>,
{
debug_assert_eq!(config.is_zk(), 0, "batch recursion assumes non-ZK");
if SC::Pcs::ZK {
return Err(GenerationError::InvalidProofShape(
"batch-STARK challenge generation does not support ZK mode",
));
}

let BatchProof {
commitments,
opened_values,
opening_proof,
degree_bits,
} = proof;

let n_instances = airs.len();
if n_instances == 0
|| opened_values.instances.len() != n_instances
|| public_values.len() != n_instances
|| degree_bits.len() != n_instances
{
return Err(GenerationError::InvalidProofShape(
"instance metadata length mismatch",
));
}

let pcs = config.pcs();
let mut challenger = config.initialise_challenger();

observe_base_as_ext::<SC>(&mut challenger, Val::<SC>::from_usize(n_instances));

for inst in &opened_values.instances {
if inst
.quotient_chunks
.iter()
.any(|c| c.len() != SC::Challenge::DIMENSION)
{
return Err(GenerationError::InvalidProofShape(
"invalid quotient chunk length",
));
}
}

let mut log_quotient_degrees = Vec::with_capacity(n_instances);
let mut quotient_degrees = Vec::with_capacity(n_instances);
for (air, pv) in airs.iter().zip(public_values.iter()) {
let log_qd = get_log_quotient_degree::<Val<SC>, A>(air, 0, pv.len(), config.is_zk());
let quotient_degree = 1 << (log_qd + config.is_zk());
log_quotient_degrees.push(log_qd);
quotient_degrees.push(quotient_degree);
}

for i in 0..n_instances {
let ext_db = degree_bits[i];
let base_db =
ext_db
.checked_sub(config.is_zk())
.ok_or(GenerationError::InvalidProofShape(
"extended degree smaller than zk adjustment",
))?;
observe_instance_binding::<SC>(
&mut challenger,
ext_db,
base_db,
A::width(&airs[i]),
quotient_degrees[i],
);
}

challenger.observe(commitments.main.clone());
for pv in public_values {
challenger.observe_slice(pv);
}
let alpha = challenger.sample_algebra_element();

challenger.observe(commitments.quotient_chunks.clone());
let zeta = challenger.sample_algebra_element();

let ext_trace_domains: Vec<_> = degree_bits
.iter()
.map(|&ext_db| pcs.natural_domain_for_degree(1 << ext_db))
.collect();

let mut coms_to_verify = Vec::new();

let trace_round = ext_trace_domains
.iter()
.zip(opened_values.instances.iter())
.map(|(ext_dom, inst)| {
let zeta_next = ext_dom
.next_point(zeta)
.ok_or(GenerationError::InvalidProofShape(
"trace domain lacks next point",
))?;
Ok((
*ext_dom,
vec![
(zeta, inst.trace_local.clone()),
(zeta_next, inst.trace_next.clone()),
],
))
})
.collect::<Result<Vec<_>, GenerationError>>()?;
coms_to_verify.push((commitments.main.clone(), trace_round));

let quotient_domains: Vec<Vec<_>> = degree_bits
.iter()
.zip(ext_trace_domains.iter())
.zip(log_quotient_degrees.iter())
.map(|((&ext_db, ext_dom), &log_qd)| {
let base_db = ext_db - config.is_zk();
let q_domain = ext_dom.create_disjoint_domain(1 << (base_db + log_qd + config.is_zk()));
q_domain.split_domains(1 << (log_qd + config.is_zk()))
})
.collect();

let mut quotient_round = Vec::new();
for (domains, inst) in quotient_domains.iter().zip(opened_values.instances.iter()) {
if inst.quotient_chunks.len() != domains.len() {
return Err(GenerationError::InvalidProofShape(
"quotient chunk count mismatch",
));
}
for (domain, values) in domains.iter().zip(inst.quotient_chunks.iter()) {
quotient_round.push((*domain, vec![(zeta, values.clone())]));
}
}
coms_to_verify.push((commitments.quotient_chunks.clone(), quotient_round));

let pcs_challenges = pcs.generate_challenges(
config,
&mut challenger,
&coms_to_verify,
opening_proof,
extra_params,
)?;

let mut challenges = Vec::with_capacity(2 + pcs_challenges.len());
challenges.push(alpha);
challenges.push(zeta);
challenges.extend(pcs_challenges);

Ok(challenges)
}

type InnerFriProof<SC, InputMmcs, FriMmcs> = FriProof<
<SC as StarkGenericConfig>::Challenge,
FriMmcs,
Expand Down
14 changes: 10 additions & 4 deletions recursion/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,14 @@ pub mod types;
pub mod verifier;

pub use challenger::CircuitChallenger;
pub use generation::{GenerationError, PcsGeneration, generate_challenges};
pub use generation::{
GenerationError, PcsGeneration, generate_batch_challenges, generate_challenges,
};
pub use pcs::fri::{FriVerifierParams, MAX_QUERY_INDEX_BITS};
pub use public_inputs::{
CommitmentOpening, FriVerifierInputs, PublicInputBuilder, StarkVerifierInputs,
StarkVerifierInputsBuilder, construct_stark_verifier_inputs,
BatchStarkVerifierInputsBuilder, CommitmentOpening, FriVerifierInputs, PublicInputBuilder,
StarkVerifierInputs, StarkVerifierInputsBuilder, construct_batch_stark_verifier_inputs,
construct_stark_verifier_inputs,
};
pub use traits::{
Recursive, RecursiveAir, RecursiveChallenger, RecursiveExtensionMmcs, RecursiveMmcs,
Expand All @@ -28,4 +31,7 @@ pub use types::{
CommitmentTargets, OpenedValuesTargets, ProofTargets, RecursiveLagrangeSelectors,
StarkChallenges, Target,
};
pub use verifier::{ObservableCommitment, VerificationError, verify_circuit};
pub use verifier::{
BatchProofTargets, InstanceOpenedValuesTargets, ObservableCommitment, VerificationError,
verify_batch_circuit, verify_circuit,
};
Loading
Loading