diff --git a/circuit-prover/src/batch_stark_prover.rs b/circuit-prover/src/batch_stark_prover.rs index 24f8d88..e9aa70a 100644 --- a/circuit-prover/src/batch_stark_prover.rs +++ b/circuit-prover/src/batch_stark_prover.rs @@ -9,9 +9,7 @@ use p3_air::{Air, AirBuilder, BaseAir}; use p3_batch_stark::{BatchProof, StarkGenericConfig as MSGC, StarkInstance, Val as MVal}; use p3_circuit::tables::Traces; use p3_field::{BasedVectorSpace, Field}; -use p3_matrix::Matrix; use p3_matrix::dense::RowMajorMatrix; -use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir}; use thiserror::Error; use tracing::instrument; @@ -28,11 +26,11 @@ pub enum Table { Public = 2, Add = 3, Mul = 4, - Mmcs = 5, } +// TODO(Robin): Remove with dynamic dispatch /// Number of circuit tables included in the unified batch STARK proof. -pub const NUM_TABLES: usize = Table::Mmcs as usize + 1; +pub const NUM_TABLES: usize = Table::Mul as usize + 1; /// Row counts wrapper with type-safe indexing. #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -85,8 +83,6 @@ where pub ext_degree: usize, /// The binomial coefficient `W` for extension field multiplication, if `ext_degree > 1`. pub w_binomial: Option>, - /// The configuration for the MMCS table. - pub mmcs_config: MmcsTableConfig, } impl core::fmt::Debug for BatchStarkProof @@ -99,7 +95,6 @@ where .field("rows", &self.rows) .field("ext_degree", &self.ext_degree) .field("w_binomial", &self.w_binomial) - .field("mmcs_config", &self.mmcs_config) .finish() } } @@ -111,7 +106,6 @@ where { config: SC, table_packing: TablePacking, - mmcs_config: MmcsTableConfig, } /// Errors for the batch STARK table prover. @@ -137,7 +131,6 @@ enum CircuitTableAir { Public(PublicAir), Add(AddAir), Mul(MulAir), - Mmcs(MmcsVerifyAir), } impl BaseAir for CircuitTableAir { @@ -148,7 +141,6 @@ impl BaseAir for CircuitTableAir { Self::Public(a) => a.width(), Self::Add(a) => a.width(), Self::Mul(a) => a.width(), - Self::Mmcs(a) => a.width(), } } } @@ -165,7 +157,6 @@ where Self::Public(a) => a.eval(builder), Self::Add(a) => a.eval(builder), Self::Mul(a) => a.eval(builder), - Self::Mmcs(a) => a.eval(builder), } } } @@ -179,7 +170,6 @@ where Self { config, table_packing: TablePacking::default(), - mmcs_config: MmcsTableConfig::default(), } } @@ -189,12 +179,6 @@ where self } - #[must_use] - pub fn with_mmcs_config(mut self, mmcs_config: MmcsTableConfig) -> Self { - self.mmcs_config = mmcs_config; - self - } - #[inline] pub const fn table_packing(&self) -> TablePacking { self.table_packing @@ -289,19 +273,12 @@ where let mul_matrix: RowMajorMatrix> = MulAir::, D>::trace_to_matrix(&traces.mul_trace, mul_lanes); - // Mmcs - let mmcs_air = MmcsVerifyAir::>::new(self.mmcs_config); - let mmcs_matrix: RowMajorMatrix> = - MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace); - let mmcs_rows: usize = mmcs_matrix.height(); - // Wrap AIRs in enum for heterogeneous batching and build instances in fixed order. let air_witness = CircuitTableAir::Witness(witness_air); let air_const = CircuitTableAir::Const(const_air); let air_public = CircuitTableAir::Public(public_air); let air_add = CircuitTableAir::Add(add_air); let air_mul = CircuitTableAir::Mul(mul_air); - let air_mmcs = CircuitTableAir::Mmcs(mmcs_air); // Pre-size for performance let mut instances = Vec::with_capacity(NUM_TABLES); @@ -331,11 +308,6 @@ where trace: mul_matrix, public_values: vec![], }, - StarkInstance { - air: &air_mmcs, - trace: mmcs_matrix, - public_values: vec![], - }, ]); let proof = p3_batch_stark::prove_batch(&self.config, instances); @@ -343,17 +315,9 @@ where Ok(BatchStarkProof { proof, table_packing: packing, - rows: RowCounts::new([ - witness_rows, - const_rows, - public_rows, - add_rows, - mul_rows, - mmcs_rows, - ]), + rows: RowCounts::new([witness_rows, const_rows, public_rows, add_rows, mul_rows]), ext_degree: D, w_binomial: if D > 1 { w_binomial } else { None }, - mmcs_config: self.mmcs_config, }) } @@ -396,16 +360,7 @@ where w, )) }; - let mmcs_air = CircuitTableAir::Mmcs(MmcsVerifyAir::>::new(proof.mmcs_config)); - - let airs = vec![ - witness_air, - const_air, - public_air, - add_air, - mul_air, - mmcs_air, - ]; + let airs = vec![witness_air, const_air, public_air, add_air, mul_air]; // TODO: Handle public values. let pvs: Vec>> = vec![Vec::new(); NUM_TABLES]; diff --git a/recursion/Cargo.toml b/recursion/Cargo.toml index 61db370..fc60c5e 100644 --- a/recursion/Cargo.toml +++ b/recursion/Cargo.toml @@ -12,7 +12,9 @@ categories.workspace = true [dependencies] # Plonky3 dependencies p3-air.workspace = true +p3-batch-stark.workspace = true p3-challenger.workspace = true +p3-circuit-prover.workspace = true p3-commit.workspace = true p3-field.workspace = true p3-fri.workspace = true diff --git a/recursion/src/generation.rs b/recursion/src/generation.rs index 911a5b9..c2e2711 100644 --- a/recursion/src/generation.rs +++ b/recursion/src/generation.rs @@ -3,9 +3,11 @@ use alloc::vec::Vec; use itertools::zip_eq; use p3_air::Air; +use p3_batch_stark::BatchProof; +use p3_batch_stark::config::{observe_base_as_ext, observe_instance_binding}; use p3_challenger::{CanObserve, CanSample, CanSampleBits, FieldChallenger, GrindingChallenger}; use p3_commit::{BatchOpening, Mmcs, Pcs, PolynomialSpace}; -use p3_field::{Field, PrimeCharacteristicRing, PrimeField, TwoAdicField}; +use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing, PrimeField, TwoAdicField}; use p3_fri::{FriProof, TwoAdicFriPcs}; use p3_uni_stark::{ Domain, Proof, StarkGenericConfig, SymbolicAirBuilder, Val, VerifierConstraintFolder, @@ -28,6 +30,9 @@ pub enum GenerationError { #[error("Witness check failed during challenge generation.")] InvalidPowWitness, + + #[error("Invalid proof shape: {0}")] + InvalidProofShape(&'static str), } /// A type alias for a single opening point and its values. @@ -182,6 +187,162 @@ where Ok(challenges) } +/// Generates the challenges used in the verification of a batch-STARK proof. +pub fn generate_batch_challenges( + airs: &[A], + config: &SC, + proof: &BatchProof, + public_values: &[Vec>], + extra_params: Option<&[usize]>, +) -> Result, GenerationError> +where + A: Air>> + for<'a> Air>, + SC::Pcs: PcsGeneration>::Proof>, +{ + debug_assert_eq!(config.is_zk(), 0, "batch recursion assumes non-ZK"); + if SC::Pcs::ZK { + return Err(GenerationError::InvalidProofShape( + "batch-STARK challenge generation does not support ZK mode", + )); + } + + let BatchProof { + commitments, + opened_values, + opening_proof, + degree_bits, + } = proof; + + let n_instances = airs.len(); + if n_instances == 0 + || opened_values.instances.len() != n_instances + || public_values.len() != n_instances + || degree_bits.len() != n_instances + { + return Err(GenerationError::InvalidProofShape( + "instance metadata length mismatch", + )); + } + + let pcs = config.pcs(); + let mut challenger = config.initialise_challenger(); + + observe_base_as_ext::(&mut challenger, Val::::from_usize(n_instances)); + + for inst in &opened_values.instances { + if inst + .quotient_chunks + .iter() + .any(|c| c.len() != SC::Challenge::DIMENSION) + { + return Err(GenerationError::InvalidProofShape( + "invalid quotient chunk length", + )); + } + } + + let mut log_quotient_degrees = Vec::with_capacity(n_instances); + let mut quotient_degrees = Vec::with_capacity(n_instances); + for (air, pv) in airs.iter().zip(public_values.iter()) { + let log_qd = get_log_quotient_degree::, A>(air, 0, pv.len(), config.is_zk()); + let quotient_degree = 1 << (log_qd + config.is_zk()); + log_quotient_degrees.push(log_qd); + quotient_degrees.push(quotient_degree); + } + + for i in 0..n_instances { + let ext_db = degree_bits[i]; + let base_db = + ext_db + .checked_sub(config.is_zk()) + .ok_or(GenerationError::InvalidProofShape( + "extended degree smaller than zk adjustment", + ))?; + observe_instance_binding::( + &mut challenger, + ext_db, + base_db, + A::width(&airs[i]), + quotient_degrees[i], + ); + } + + challenger.observe(commitments.main.clone()); + for pv in public_values { + challenger.observe_slice(pv); + } + let alpha = challenger.sample_algebra_element(); + + challenger.observe(commitments.quotient_chunks.clone()); + let zeta = challenger.sample_algebra_element(); + + let ext_trace_domains: Vec<_> = degree_bits + .iter() + .map(|&ext_db| pcs.natural_domain_for_degree(1 << ext_db)) + .collect(); + + let mut coms_to_verify = Vec::new(); + + let trace_round = ext_trace_domains + .iter() + .zip(opened_values.instances.iter()) + .map(|(ext_dom, inst)| { + let zeta_next = ext_dom + .next_point(zeta) + .ok_or(GenerationError::InvalidProofShape( + "trace domain lacks next point", + ))?; + Ok(( + *ext_dom, + vec![ + (zeta, inst.trace_local.clone()), + (zeta_next, inst.trace_next.clone()), + ], + )) + }) + .collect::, GenerationError>>()?; + coms_to_verify.push((commitments.main.clone(), trace_round)); + + let quotient_domains: Vec> = degree_bits + .iter() + .zip(ext_trace_domains.iter()) + .zip(log_quotient_degrees.iter()) + .map(|((&ext_db, ext_dom), &log_qd)| { + let base_db = ext_db - config.is_zk(); + let q_domain = ext_dom.create_disjoint_domain(1 << (base_db + log_qd + config.is_zk())); + q_domain.split_domains(1 << (log_qd + config.is_zk())) + }) + .collect(); + + let mut quotient_round = Vec::new(); + for (domains, inst) in quotient_domains.iter().zip(opened_values.instances.iter()) { + if inst.quotient_chunks.len() != domains.len() { + return Err(GenerationError::InvalidProofShape( + "quotient chunk count mismatch", + )); + } + for (domain, values) in domains.iter().zip(inst.quotient_chunks.iter()) { + quotient_round.push((*domain, vec![(zeta, values.clone())])); + } + } + coms_to_verify.push((commitments.quotient_chunks.clone(), quotient_round)); + + let pcs_challenges = pcs.generate_challenges( + config, + &mut challenger, + &coms_to_verify, + opening_proof, + extra_params, + )?; + + let mut challenges = Vec::with_capacity(2 + pcs_challenges.len()); + challenges.push(alpha); + challenges.push(zeta); + challenges.extend(pcs_challenges); + + Ok(challenges) +} + type InnerFriProof = FriProof< ::Challenge, FriMmcs, diff --git a/recursion/src/lib.rs b/recursion/src/lib.rs index 5b1601b..c7fd0e9 100644 --- a/recursion/src/lib.rs +++ b/recursion/src/lib.rs @@ -14,11 +14,14 @@ pub mod types; pub mod verifier; pub use challenger::CircuitChallenger; -pub use generation::{GenerationError, PcsGeneration, generate_challenges}; +pub use generation::{ + GenerationError, PcsGeneration, generate_batch_challenges, generate_challenges, +}; pub use pcs::fri::{FriVerifierParams, MAX_QUERY_INDEX_BITS}; pub use public_inputs::{ - CommitmentOpening, FriVerifierInputs, PublicInputBuilder, StarkVerifierInputs, - StarkVerifierInputsBuilder, construct_stark_verifier_inputs, + BatchStarkVerifierInputsBuilder, CommitmentOpening, FriVerifierInputs, PublicInputBuilder, + StarkVerifierInputs, StarkVerifierInputsBuilder, construct_batch_stark_verifier_inputs, + construct_stark_verifier_inputs, }; pub use traits::{ Recursive, RecursiveAir, RecursiveChallenger, RecursiveExtensionMmcs, RecursiveMmcs, @@ -28,4 +31,7 @@ pub use types::{ CommitmentTargets, OpenedValuesTargets, ProofTargets, RecursiveLagrangeSelectors, StarkChallenges, Target, }; -pub use verifier::{ObservableCommitment, VerificationError, verify_circuit}; +pub use verifier::{ + BatchProofTargets, InstanceOpenedValuesTargets, ObservableCommitment, VerificationError, + verify_batch_circuit, verify_circuit, +}; diff --git a/recursion/src/pcs/fri/targets.rs b/recursion/src/pcs/fri/targets.rs index 6e142ab..a25fd7e 100644 --- a/recursion/src/pcs/fri/targets.rs +++ b/recursion/src/pcs/fri/targets.rs @@ -437,6 +437,7 @@ where InputProofTargets, SC::Challenge, RecursiveInputMmcs>, >; + /// Observes all opened values and derives PCS-specific challenges. fn get_challenges_circuit( circuit: &mut CircuitBuilder, challenger: &mut CircuitChallenger, @@ -446,7 +447,6 @@ where ) -> Vec { let fri_proof = &proof_targets.opening_proof; - // Observe all opened values (trace, quotient chunks, random) opened_values.observe(circuit, challenger); // Sample FRI alpha (for batch opening reduction) diff --git a/recursion/src/public_inputs.rs b/recursion/src/public_inputs.rs index 4bb8589..4529e93 100644 --- a/recursion/src/public_inputs.rs +++ b/recursion/src/public_inputs.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; +use p3_batch_stark::BatchProof; use p3_circuit::CircuitBuilder; use p3_commit::Pcs; use p3_field::{BasedVectorSpace, Field, PrimeField64}; @@ -11,6 +12,7 @@ use p3_uni_stark::{Proof, StarkGenericConfig, Val}; use crate::ProofTargets; use crate::pcs::MAX_QUERY_INDEX_BITS; use crate::traits::Recursive; +use crate::verifier::BatchProofTargets; /// Builder for constructing public inputs. /// @@ -239,6 +241,44 @@ where .build() } +/// Constructs the public input values for a multi-instance STARK verification circuit. +pub fn construct_batch_stark_verifier_inputs( + air_public_values: &[Vec], + proof_values: &[EF], + challenges: &[EF], + num_queries: usize, +) -> Vec +where + F: Field + PrimeField64, + EF: Field + BasedVectorSpace + From, +{ + let mut builder = PublicInputBuilder::new(); + + for instance_pv in air_public_values { + builder.add_proof_values(instance_pv.iter().map(|&v| v.into())); + } + + builder.add_proof_values(proof_values.iter().copied()); + builder.add_challenges(challenges.iter().copied()); + + let num_regular_challenges = challenges.len().saturating_sub(num_queries); + for &query_index in &challenges[num_regular_challenges..] { + let coeffs = query_index.as_basis_coefficients_slice(); + let index_usize = coeffs[0].as_canonical_u64() as usize; + + for k in 0..MAX_QUERY_INDEX_BITS { + let bit: EF = if (index_usize >> k) & 1 == 1 { + EF::ONE + } else { + EF::ZERO + }; + builder.add_challenge(bit); + } + } + + builder.build() +} + /// Builder that handles both target allocation during circuit creation and value packing during execution. /// /// # Example @@ -335,6 +375,81 @@ where } } +/// Builder for multi-instance STARK verification circuits. +pub struct BatchStarkVerifierInputsBuilder +where + SC: StarkGenericConfig, + Comm: Recursive< + SC::Challenge, + Input = >::Commitment, + >, + OpeningProof: + Recursive>::Proof>, +{ + /// AIR public input targets per instance. + pub air_public_targets: Vec>, + /// Allocated proof structure targets. + pub proof_targets: BatchProofTargets, +} + +impl BatchStarkVerifierInputsBuilder +where + SC: StarkGenericConfig, + Comm: Recursive< + SC::Challenge, + Input = >::Commitment, + >, + OpeningProof: + Recursive>::Proof>, +{ + /// Allocate all targets during circuit building. + pub fn allocate( + circuit: &mut CircuitBuilder, + proof: &BatchProof, + air_public_counts: &[usize], + ) -> Self { + assert_eq!( + air_public_counts.len(), + proof.opened_values.instances.len(), + "public input count must match number of instances" + ); + + let air_public_targets = air_public_counts + .iter() + .map(|&count| (0..count).map(|_| circuit.add_public_input()).collect()) + .collect(); + + let proof_targets = BatchProofTargets::new(circuit, proof); + + Self { + air_public_targets, + proof_targets, + } + } + + /// Pack actual values in the same order as allocated targets. + pub fn pack_values( + &self, + air_public_values: &[Vec>], + proof: &BatchProof, + challenges: &[SC::Challenge], + num_queries: usize, + ) -> Vec + where + Val: PrimeField64, + SC::Challenge: BasedVectorSpace> + From>, + { + let proof_values = BatchProofTargets::::get_values(proof); + + construct_batch_stark_verifier_inputs( + air_public_values, + &proof_values, + challenges, + num_queries, + ) + } +} + #[cfg(test)] mod tests { use p3_baby_bear::BabyBear; diff --git a/recursion/src/verifier/batch_stark.rs b/recursion/src/verifier/batch_stark.rs new file mode 100644 index 0000000..d378fa8 --- /dev/null +++ b/recursion/src/verifier/batch_stark.rs @@ -0,0 +1,550 @@ +use alloc::string::ToString; +use alloc::vec::Vec; +use alloc::{format, vec}; +use core::marker::PhantomData; + +use p3_air::{Air as P3Air, AirBuilder as P3AirBuilder, BaseAir as P3BaseAir}; +use p3_batch_stark::BatchProof; +use p3_circuit::CircuitBuilder; +use p3_circuit::utils::ColumnsTargets; +use p3_circuit_prover::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir}; +use p3_circuit_prover::batch_stark_prover::{RowCounts, Table}; +use p3_commit::{Pcs, PolynomialSpace}; +use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing}; +use p3_uni_stark::StarkGenericConfig; + +use super::{ObservableCommitment, VerificationError, recompose_quotient_from_chunks_circuit}; +use crate::challenger::CircuitChallenger; +use crate::traits::{Recursive, RecursiveAir, RecursiveChallenger, RecursivePcs}; +use crate::types::{CommitmentTargets, OpenedValuesTargets, ProofTargets}; +use crate::{BatchStarkVerifierInputsBuilder, Target}; + +/// Type alias for PCS verifier parameters. +pub type PcsVerifierParams = + <::Pcs as RecursivePcs< + SC, + InputProof, + OpeningProof, + Comm, + <::Pcs as Pcs< + ::Challenge, + ::Challenger, + >>::Domain, + >>::VerifierParams; + +// TODO(Robin): Remove with dynamic dispatch +/// Wrapper enum for heterogeneous circuit table AIRs used by circuit-prover tables. +pub enum CircuitTablesAir { + Witness(WitnessAir), + Const(ConstAir), + Public(PublicAir), + Add(AddAir), + Mul(MulAir), +} + +impl P3BaseAir for CircuitTablesAir { + fn width(&self) -> usize { + match self { + Self::Witness(a) => P3BaseAir::width(a), + Self::Const(a) => P3BaseAir::width(a), + Self::Public(a) => P3BaseAir::width(a), + Self::Add(a) => P3BaseAir::width(a), + Self::Mul(a) => P3BaseAir::width(a), + } + } +} + +impl P3Air for CircuitTablesAir +where + AB: P3AirBuilder, + AB::F: Field, +{ + fn eval(&self, builder: &mut AB) { + match self { + Self::Witness(a) => a.eval(builder), + Self::Const(a) => a.eval(builder), + Self::Public(a) => a.eval(builder), + Self::Add(a) => a.eval(builder), + Self::Mul(a) => a.eval(builder), + } + } +} + +/// Build and attach a recursive verifier circuit for a circuit-prover BatchStarkProof. +/// +/// This reconstructs the circuit table AIRs from the proof metadata (rows + packing) so callers +/// don't need to pass `circuit_airs` explicitly. Returns the allocated input builder to pack +/// public inputs afterwards. +pub fn verify_p3_recursion_proof_circuit< + SC: StarkGenericConfig, + Comm: Recursive< + SC::Challenge, + Input = >::Commitment, + > + Clone + + ObservableCommitment, + InputProof: Recursive, + OpeningProof: Recursive>::Proof>, + const RATE: usize, + const TRACE_D: usize, +>( + config: &SC, + circuit: &mut CircuitBuilder, + proof: &p3_circuit_prover::batch_stark_prover::BatchStarkProof, + pcs_params: &PcsVerifierParams, +) -> Result, VerificationError> +where + ::Pcs: RecursivePcs< + SC, + InputProof, + OpeningProof, + Comm, + >::Domain, + >, + SC::Challenge: PrimeCharacteristicRing, + <::Pcs as Pcs>::Domain: Clone, +{ + assert_eq!(proof.ext_degree, TRACE_D, "trace extension degree mismatch"); + let rows: RowCounts = proof.rows; + let packing = proof.table_packing; + let add_lanes = packing.add_lanes(); + let mul_lanes = packing.mul_lanes(); + + let circuit_airs = vec![ + CircuitTablesAir::Witness(WitnessAir::::new( + rows[Table::Witness], + )), + CircuitTablesAir::Const(ConstAir::::new(rows[Table::Const])), + CircuitTablesAir::Public(PublicAir::::new( + rows[Table::Public], + )), + CircuitTablesAir::Add(AddAir::::new( + rows[Table::Add], + add_lanes, + )), + CircuitTablesAir::Mul(MulAir::::new( + rows[Table::Mul], + mul_lanes, + )), + ]; + + // TODO: public values are empty for all circuit tables for now. + let air_public_counts = vec![0usize; proof.proof.opened_values.instances.len()]; + let verifier_inputs = BatchStarkVerifierInputsBuilder::::allocate( + circuit, + &proof.proof, + &air_public_counts, + ); + + verify_batch_circuit::< + CircuitTablesAir, + SC, + Comm, + InputProof, + OpeningProof, + RATE, + >( + config, + &circuit_airs, + circuit, + &verifier_inputs.proof_targets, + &verifier_inputs.air_public_targets, + pcs_params, + )?; + + Ok(verifier_inputs) +} + +/// Opened values for a single STARK instance within the batch-proof. +#[derive(Clone)] +pub struct InstanceOpenedValuesTargets { + pub trace_local: Vec, + pub trace_next: Vec, + pub quotient_chunks: Vec>, + _phantom: PhantomData, +} + +/// Recursive targets for a batch-STARK proof. +/// +/// The `flattened` field stores the aggregated commitments, opened values, and opening proof in the +/// same layout expected by single-instance PCS logic. The `instances` field retains per-instance +/// opened values so that AIR constraints can be enforced individually. +pub struct BatchProofTargets< + SC: StarkGenericConfig, + Comm: Recursive, + OpeningProof: Recursive, +> { + pub flattened: ProofTargets, + pub instances: Vec>, + pub degree_bits: Vec, +} + +impl< + SC: StarkGenericConfig, + Comm: Recursive>::Commitment>, + OpeningProof: Recursive>::Proof>, +> Recursive for BatchProofTargets +{ + type Input = BatchProof; + + fn new(circuit: &mut CircuitBuilder, input: &Self::Input) -> Self { + let trace_targets = Comm::new(circuit, &input.commitments.main); + let quotient_chunks_targets = Comm::new(circuit, &input.commitments.quotient_chunks); + + // Flattened opened values are ordered as: + // 1. All `trace_local` rows per instance (instance 0 .. N) + // 2. All `trace_next` rows per instance (instance 0 .. N) + // 3. Quotient chunks for each instance in commit order + let mut aggregated_trace_local = Vec::new(); + let mut aggregated_trace_next = Vec::new(); + let mut aggregated_quotient_chunks = Vec::new(); + let mut instances = Vec::with_capacity(input.opened_values.instances.len()); + + for inst in &input.opened_values.instances { + let trace_local = + circuit.alloc_public_inputs(inst.trace_local.len(), "trace local values"); + aggregated_trace_local.extend(trace_local.iter().copied()); + + let trace_next = + circuit.alloc_public_inputs(inst.trace_next.len(), "trace next values"); + aggregated_trace_next.extend(trace_next.iter().copied()); + + let mut quotient_chunks = Vec::with_capacity(inst.quotient_chunks.len()); + for chunk in &inst.quotient_chunks { + let chunk_targets = + circuit.alloc_public_inputs(chunk.len(), "quotient chunk values"); + aggregated_quotient_chunks.push(chunk_targets.clone()); + quotient_chunks.push(chunk_targets); + } + + instances.push(InstanceOpenedValuesTargets { + trace_local, + trace_next, + quotient_chunks, + _phantom: PhantomData, + }); + } + + let opened_values_targets = OpenedValuesTargets { + trace_local_targets: aggregated_trace_local, + trace_next_targets: aggregated_trace_next, + quotient_chunks_targets: aggregated_quotient_chunks, + random_targets: None, + _phantom: PhantomData, + }; + + let flattened = ProofTargets { + commitments_targets: CommitmentTargets { + trace_targets, + quotient_chunks_targets, + random_commit: None, + _phantom: PhantomData, + }, + opened_values_targets, + opening_proof: OpeningProof::new(circuit, &input.opening_proof), + // Placeholder value: degree_bits is not used from the flattened ProofTargets in batch verification. + // The actual per-instance degree bits are stored in BatchProofTargets.degree_bits (Vec) + // and used directly by the verifier. The flattened structure is only used for PCS verification + // which doesn't access this field. + degree_bits: 0, + }; + + Self { + flattened, + instances, + degree_bits: input.degree_bits.clone(), + } + } + + fn get_values(input: &Self::Input) -> Vec { + let commitments = p3_uni_stark::Commitments { + trace: input.commitments.main.clone(), + quotient_chunks: input.commitments.quotient_chunks.clone(), + random: None, + }; + + let mut values = CommitmentTargets::::get_values(&commitments); + + // Opened values, preserving per-instance allocation order. + for inst in &input.opened_values.instances { + values.extend(inst.trace_local.iter().copied()); + values.extend(inst.trace_next.iter().copied()); + for chunk in &inst.quotient_chunks { + values.extend(chunk.iter().copied()); + } + } + + values.extend(OpeningProof::get_values(&input.opening_proof)); + values + } +} + +/// Verify a batch-STARK proof inside a recursive circuit. +pub fn verify_batch_circuit< + A, + SC: StarkGenericConfig, + Comm: Recursive< + SC::Challenge, + Input = >::Commitment, + > + Clone + + ObservableCommitment, + InputProof: Recursive, + OpeningProof: Recursive, + const RATE: usize, +>( + config: &SC, + airs: &[A], + circuit: &mut CircuitBuilder, + proof_targets: &BatchProofTargets, + public_values: &[Vec], + pcs_params: &PcsVerifierParams, +) -> Result<(), VerificationError> +where + A: RecursiveAir, + ::Pcs: RecursivePcs< + SC, + InputProof, + OpeningProof, + Comm, + >::Domain, + >, + SC::Challenge: PrimeCharacteristicRing, + <::Pcs as Pcs>::Domain: Clone, +{ + //TODO: Add support for ZK mode. + debug_assert_eq!(config.is_zk(), 0, "batch recursion assumes non-ZK"); + if airs.is_empty() { + return Err(VerificationError::InvalidProofShape( + "batch-STARK verification requires at least one instance".to_string(), + )); + } + + if airs.len() != proof_targets.instances.len() + || airs.len() != public_values.len() + || airs.len() != proof_targets.degree_bits.len() + { + return Err(VerificationError::InvalidProofShape( + "Mismatch between number of AIRs, instances, public values, or degree bits".to_string(), + )); + } + + let pcs = config.pcs(); + + let flattened = &proof_targets.flattened; + let commitments_targets = &flattened.commitments_targets; + let opened_values_targets = &flattened.opened_values_targets; + let opening_proof = &flattened.opening_proof; + let instances = &proof_targets.instances; + let degree_bits = &proof_targets.degree_bits; + + if commitments_targets.random_commit.is_some() { + return Err(VerificationError::InvalidProofShape( + "Batch-STARK verifier does not support random commitments".to_string(), + )); + } + + let n_instances = airs.len(); + + // Pre-compute per-instance quotient degrees and validate proof shape. + let mut log_quotient_degrees = Vec::with_capacity(n_instances); + let mut quotient_degrees = Vec::with_capacity(n_instances); + for ((air, instance), public_vals) in airs.iter().zip(instances.iter()).zip(public_values) { + let air_width = A::width(air); + if instance.trace_local.len() != air_width || instance.trace_next.len() != air_width { + return Err(VerificationError::InvalidProofShape(format!( + "Instance has incorrect trace width: expected {}, got {} / {}", + air_width, + instance.trace_local.len(), + instance.trace_next.len() + ))); + } + + let log_qd = A::get_log_quotient_degree(air, public_vals.len(), config.is_zk()); + let quotient_degree = 1 << (log_qd + config.is_zk()); + + if instance.quotient_chunks.len() != quotient_degree { + return Err(VerificationError::InvalidProofShape(format!( + "Instance quotient chunk count mismatch: expected {}, got {}", + quotient_degree, + instance.quotient_chunks.len() + ))); + } + + if instance + .quotient_chunks + .iter() + .any(|chunk| chunk.len() != SC::Challenge::DIMENSION) + { + return Err(VerificationError::InvalidProofShape(format!( + "Invalid quotient chunk length: expected {}", + SC::Challenge::DIMENSION + ))); + } + + log_quotient_degrees.push(log_qd); + quotient_degrees.push(quotient_degree); + } + + // Challenger initialisation mirrors the native batch-STARK verifier transcript. + let mut challenger = CircuitChallenger::::new(); + let inst_count_target = circuit.alloc_const( + SC::Challenge::from_usize(n_instances), + "number of instances", + ); + challenger.observe(circuit, inst_count_target); + + for ((&ext_db, quotient_degree), air) in degree_bits + .iter() + .zip(quotient_degrees.iter()) + .zip(airs.iter()) + { + let base_db = ext_db.checked_sub(config.is_zk()).ok_or_else(|| { + VerificationError::InvalidProofShape( + "Extended degree bits smaller than ZK adjustment".to_string(), + ) + })?; + let base_db_target = + circuit.alloc_const(SC::Challenge::from_usize(base_db), "base degree bits"); + let ext_db_target = + circuit.alloc_const(SC::Challenge::from_usize(ext_db), "extended degree bits"); + let width_target = + circuit.alloc_const(SC::Challenge::from_usize(A::width(air)), "air width"); + let quotient_chunks_target = circuit.alloc_const( + SC::Challenge::from_usize(*quotient_degree), + "quotient chunk count", + ); + + challenger.observe(circuit, ext_db_target); + challenger.observe(circuit, base_db_target); + challenger.observe(circuit, width_target); + challenger.observe(circuit, quotient_chunks_target); + } + + challenger.observe_slice( + circuit, + &commitments_targets.trace_targets.to_observation_targets(), + ); + for pv in public_values { + challenger.observe_slice(circuit, pv); + } + let alpha = challenger.sample(circuit); + + challenger.observe_slice( + circuit, + &commitments_targets + .quotient_chunks_targets + .to_observation_targets(), + ); + let zeta = challenger.sample(circuit); + + // Build per-instance domains. + let mut trace_domains = Vec::with_capacity(n_instances); + let mut ext_trace_domains = Vec::with_capacity(n_instances); + for &ext_db in degree_bits { + let base_db = ext_db - config.is_zk(); + trace_domains.push(pcs.natural_domain_for_degree(1 << base_db)); + ext_trace_domains.push(pcs.natural_domain_for_degree(1 << ext_db)); + } + + // Collect commitments with opening points for PCS verification. + let mut coms_to_verify = vec![]; + + let trace_round: Vec<_> = ext_trace_domains + .iter() + .zip(instances.iter()) + .map(|(ext_dom, inst)| { + let first_point = pcs.first_point(ext_dom); + let next_point = ext_dom.next_point(first_point).ok_or_else(|| { + VerificationError::InvalidProofShape( + "Trace domain does not provide next point".to_string(), + ) + })?; + let generator = next_point * first_point.inverse(); + let generator_const = circuit.add_const(generator); + let zeta_next = circuit.mul(zeta, generator_const); + Ok(( + *ext_dom, + vec![ + (zeta, inst.trace_local.clone()), + (zeta_next, inst.trace_next.clone()), + ], + )) + }) + .collect::>()?; + coms_to_verify.push((commitments_targets.trace_targets.clone(), trace_round)); + + let quotient_domains: Vec> = degree_bits + .iter() + .zip(ext_trace_domains.iter()) + .zip(log_quotient_degrees.iter()) + .map(|((&ext_db, ext_dom), &log_qd)| { + let base_db = ext_db - config.is_zk(); + let q_domain = ext_dom.create_disjoint_domain(1 << (base_db + log_qd + config.is_zk())); + q_domain.split_domains(1 << (log_qd + config.is_zk())) + }) + .collect(); + + let mut quotient_round = Vec::new(); + for (domains, inst) in quotient_domains.iter().zip(instances.iter()) { + if domains.len() != inst.quotient_chunks.len() { + return Err(VerificationError::InvalidProofShape( + "Quotient chunk count mismatch across domains".to_string(), + )); + } + for (domain, values) in domains.iter().zip(inst.quotient_chunks.iter()) { + quotient_round.push((*domain, vec![(zeta, values.clone())])); + } + } + coms_to_verify.push(( + commitments_targets.quotient_chunks_targets.clone(), + quotient_round, + )); + + let pcs_challenges = SC::Pcs::get_challenges_circuit::( + circuit, + &mut challenger, + flattened, + opened_values_targets, + pcs_params, + ); + + pcs.verify_circuit( + circuit, + &pcs_challenges, + &coms_to_verify, + opening_proof, + pcs_params, + )?; + + // Verify AIR constraints per instance. + for i in 0..n_instances { + let air = &airs[i]; + let inst = &instances[i]; + let trace_domain = &trace_domains[i]; + let public_vals = &public_values[i]; + let domains = "ient_domains[i]; + + let quotient = recompose_quotient_from_chunks_circuit::( + circuit, + domains, + &inst.quotient_chunks, + zeta, + pcs, + ); + + let sels = pcs.selectors_at_point_circuit(circuit, trace_domain, &zeta); + let columns_targets = ColumnsTargets { + challenges: &[], + public_values: public_vals, + local_prep_values: &[], + next_prep_values: &[], + local_values: &inst.trace_local, + next_values: &inst.trace_next, + }; + let folded_constraints = air.eval_folded_circuit(circuit, &sels, &alpha, columns_targets); + + let folded_mul = circuit.mul(folded_constraints, sels.inv_vanishing); + circuit.connect(folded_mul, quotient); + } + + Ok(()) +} diff --git a/recursion/src/verifier/mod.rs b/recursion/src/verifier/mod.rs index 1a460ef..4bdb423 100644 --- a/recursion/src/verifier/mod.rs +++ b/recursion/src/verifier/mod.rs @@ -1,10 +1,15 @@ //! STARK verification within recursive circuits. +mod batch_stark; mod errors; mod observable; mod quotient; mod stark; +pub use batch_stark::{ + BatchProofTargets, CircuitTablesAir, InstanceOpenedValuesTargets, PcsVerifierParams, + verify_batch_circuit, verify_p3_recursion_proof_circuit, +}; pub use errors::VerificationError; pub use observable::ObservableCommitment; pub use quotient::recompose_quotient_from_chunks_circuit; diff --git a/recursion/src/verifier/stark.rs b/recursion/src/verifier/stark.rs index a3c02e7..8705aff 100644 --- a/recursion/src/verifier/stark.rs +++ b/recursion/src/verifier/stark.rs @@ -10,7 +10,6 @@ use p3_circuit::utils::ColumnsTargets; use p3_commit::Pcs; use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; use p3_uni_stark::StarkGenericConfig; -use p3_util::zip_eq::zip_eq; use super::{ObservableCommitment, VerificationError, recompose_quotient_from_chunks_circuit}; use crate::Target; @@ -188,15 +187,19 @@ where ( quotient_chunks_targets.clone(), // Check the commitment on the randomized domains - zip_eq( - randomized_quotient_chunks_domains.iter(), - opened_quotient_chunks_targets, - VerificationError::InvalidProofShape( - "Randomized quotient chunks length mismatch".to_string(), - ), - )? - .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) - .collect_vec(), + { + if randomized_quotient_chunks_domains.len() != opened_quotient_chunks_targets.len() + { + return Err(VerificationError::InvalidProofShape( + "Randomized quotient chunks length mismatch".to_string(), + )); + } + randomized_quotient_chunks_domains + .iter() + .zip(opened_quotient_chunks_targets) + .map(|(domain, values)| (*domain, vec![(zeta, values.clone())])) + .collect_vec() + }, ), ]); diff --git a/recursion/tests/fibonacci_batch_stark_prover.rs b/recursion/tests/fibonacci_batch_stark_prover.rs new file mode 100644 index 0000000..cd16e87 --- /dev/null +++ b/recursion/tests/fibonacci_batch_stark_prover.rs @@ -0,0 +1,260 @@ +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_baby_bear::{BabyBear, Poseidon2BabyBear}; +use p3_challenger::DuplexChallenger; +use p3_circuit::CircuitBuilder; +use p3_circuit_prover::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir}; +use p3_circuit_prover::batch_stark_prover::Table; +use p3_circuit_prover::{BatchStarkProver, TablePacking}; +use p3_commit::ExtensionMmcs; +use p3_dft::Radix2DitParallel; +use p3_field::extension::BinomialExtensionField; +use p3_field::{Field, PrimeCharacteristicRing}; +use p3_fri::{TwoAdicFriPcs, create_test_fri_params}; +use p3_merkle_tree::MerkleTreeMmcs; +use p3_recursion::generation::generate_batch_challenges; +use p3_recursion::pcs::fri::{ + FriProofTargets, FriVerifierParams, HashTargets, InputProofTargets, RecExtensionValMmcs, + RecValMmcs, Witness, +}; +use p3_recursion::verifier::verify_p3_recursion_proof_circuit; +use p3_symmetric::{PaddingFreeSponge, TruncatedPermutation}; +use p3_uni_stark::{StarkConfig, StarkGenericConfig, Val}; +use rand::SeedableRng; +use rand::rngs::SmallRng; + +type F = BabyBear; + +/// Wrapper enum for heterogeneous circuit table AIRs +enum CircuitTableAir { + Witness(WitnessAir), + Const(ConstAir), + Public(PublicAir), + Add(AddAir), + Mul(MulAir), +} + +impl BaseAir for CircuitTableAir { + fn width(&self) -> usize { + match self { + Self::Witness(a) => a.width(), + Self::Const(a) => a.width(), + Self::Public(a) => a.width(), + Self::Add(a) => a.width(), + Self::Mul(a) => a.width(), + } + } +} + +impl Air for CircuitTableAir +where + AB: AirBuilder, + AB::F: Field, +{ + fn eval(&self, builder: &mut AB) { + match self { + Self::Witness(a) => a.eval(builder), + Self::Const(a) => a.eval(builder), + Self::Public(a) => a.eval(builder), + Self::Add(a) => a.eval(builder), + Self::Mul(a) => a.eval(builder), + } + } +} + +// Type aliases for the BabyBear config with D=4 extension +const D: usize = 4; +const RATE: usize = 8; +const DIGEST_ELEMS: usize = 8; +type Challenge = BinomialExtensionField; +type Dft = Radix2DitParallel; +type Perm = Poseidon2BabyBear<16>; +type MyHash = PaddingFreeSponge; +type MyCompress = TruncatedPermutation; +type ValMmcs = MerkleTreeMmcs<::Packing, ::Packing, MyHash, MyCompress, 8>; +type ChallengeMmcs = ExtensionMmcs; +type Challenger = DuplexChallenger; +type MyPcs = TwoAdicFriPcs; +type MyConfig = StarkConfig; + +// Type for the FRI proof used in recursive verification +type InnerFri = FriProofTargets< + Val, + ::Challenge, + RecExtensionValMmcs< + Val, + ::Challenge, + DIGEST_ELEMS, + RecValMmcs, DIGEST_ELEMS, MyHash, MyCompress>, + >, + InputProofTargets< + Val, + ::Challenge, + RecValMmcs, DIGEST_ELEMS, MyHash, MyCompress>, + >, + Witness>, +>; + +#[test] +fn test_fibonacci_batch_verifier() { + let n: usize = 100; + + let mut builder = CircuitBuilder::new(); + + // Public input: expected F(n) + let expected_result = builder.alloc_public_input("expected_result"); + + // Compute F(n) iteratively + let mut a = builder.alloc_const(F::ZERO, "F(0)"); + let mut b = builder.alloc_const(F::ONE, "F(1)"); + + // TODO: remove this once we always have non-empty MUL tables + builder.mul(a, b); + for _i in 2..=n { + let next = builder.add(a, b); + a = b; + b = next; + } + + // Assert computed F(n) equals expected result + builder.connect(b, expected_result); + + builder.dump_allocation_log(); + + let circuit = builder.build().unwrap(); + let mut runner = circuit.runner(); + + // Set public input + let expected_fib = compute_fibonacci_classical(n); + runner.set_public_inputs(&[expected_fib]).unwrap(); + + let traces = runner.run().unwrap(); + + // Use a seeded RNG for deterministic permutations + let mut rng = SmallRng::seed_from_u64(42); + let perm = Perm::new_from_rng_128(&mut rng); + let hash = MyHash::new(perm.clone()); + let compress = MyCompress::new(perm.clone()); + let val_mmcs = ValMmcs::new(hash, compress); + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + let dft = Dft::default(); + + // Create test FRI params with log_final_poly_len = 0 + let fri_params = create_test_fri_params(challenge_mmcs, 0); + + // Create config for proving + let pcs_proving = MyPcs::new(dft, val_mmcs.clone(), fri_params); + let challenger_proving = Challenger::new(perm.clone()); + let config_proving = MyConfig::new(pcs_proving, challenger_proving); + + let table_packing = TablePacking::from_counts(4, 1); + let prover = BatchStarkProver::new(config_proving).with_table_packing(table_packing); + let batch_stark_proof = prover.prove_all_tables(&traces).unwrap(); + prover.verify_all_tables(&batch_stark_proof).unwrap(); + + // Now verify the batch STARK proof recursively + let dft2 = Dft::default(); + let mut rng2 = SmallRng::seed_from_u64(42); + let perm2 = Perm::new_from_rng_128(&mut rng2); + let hash2 = MyHash::new(perm2.clone()); + let compress2 = MyCompress::new(perm2.clone()); + let val_mmcs2 = ValMmcs::new(hash2, compress2); + let challenge_mmcs2 = ChallengeMmcs::new(val_mmcs2.clone()); + let fri_params2 = create_test_fri_params(challenge_mmcs2, 0); + let fri_verifier_params = FriVerifierParams::from(&fri_params2); + let pow_bits = fri_params2.proof_of_work_bits; + let log_height_max = fri_params2.log_final_poly_len + fri_params2.log_blowup; + let pcs_verif = MyPcs::new(dft2, val_mmcs2, fri_params2); + let challenger_verif = Challenger::new(perm2); + let config = MyConfig::new(pcs_verif, challenger_verif); + + // Extract proof components + let batch_proof = &batch_stark_proof.proof; + let rows = batch_stark_proof.rows; + let packing = batch_stark_proof.table_packing; + + const TRACE_D: usize = 1; // Proof traces are in base field + + // Base field AIRs for native challenge generation + let native_airs = vec![ + CircuitTableAir::Witness(WitnessAir::::new(rows[Table::Witness])), + CircuitTableAir::Const(ConstAir::::new(rows[Table::Const])), + CircuitTableAir::Public(PublicAir::::new(rows[Table::Public])), + CircuitTableAir::Add(AddAir::::new( + rows[Table::Add], + packing.add_lanes(), + )), + CircuitTableAir::Mul(MulAir::::new( + rows[Table::Mul], + packing.mul_lanes(), + )), + ]; + + // Public values (empty for all 5 circuit tables, using base field) + let pis: Vec> = vec![vec![]; 5]; + + // Build the recursive verification circuit + let mut circuit_builder = CircuitBuilder::new(); + + // Attach verifier without manually building circuit_airs + let verifier_inputs = verify_p3_recursion_proof_circuit::< + MyConfig, + HashTargets, + InputProofTargets>, + InnerFri, + RATE, + TRACE_D, + >( + &config, + &mut circuit_builder, + &batch_stark_proof, + &fri_verifier_params, + ) + .unwrap(); + + // Build the circuit + let verification_circuit = circuit_builder.build().unwrap(); + let expected_public_input_len = verification_circuit.public_flat_len; + + // Generate all the challenge values for batch proof (uses base field AIRs) + let all_challenges = generate_batch_challenges( + &native_airs, + &config, + batch_proof, + &pis, + Some(&[pow_bits, log_height_max]), + ) + .unwrap(); + + // Pack values using the builder + let num_queries = batch_proof.opening_proof.query_proofs.len(); + let public_inputs = + verifier_inputs.pack_values(&pis, batch_proof, &all_challenges, num_queries); + + assert_eq!(public_inputs.len(), expected_public_input_len); + assert!(!public_inputs.is_empty()); + + // Actually run the circuit to ensure constraints are satisfiable + let mut runner = verification_circuit.runner(); + runner.set_public_inputs(&public_inputs).unwrap(); + let _traces = runner.run().unwrap(); +} + +fn compute_fibonacci_classical(n: usize) -> F { + if n == 0 { + return F::ZERO; + } + if n == 1 { + return F::ONE; + } + + let mut a = F::ZERO; + let mut b = F::ONE; + + for _i in 2..=n { + let next = a + b; + a = b; + b = next; + } + + b +}