Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
214 changes: 159 additions & 55 deletions circuit-prover/src/batch_stark_prover.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,12 @@ use p3_batch_stark::{BatchProof, StarkGenericConfig as MSGC, StarkInstance, Val
use p3_circuit::tables::Traces;
use p3_field::{BasedVectorSpace, Field};
use p3_matrix::Matrix;
use p3_matrix::dense::RowMajorMatrix;
use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir};
use thiserror::Error;
use tracing::instrument;

use crate::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir};
use crate::config::StarkField;
use crate::field_params::ExtractBinomialW;
use crate::prover::TablePacking;

#[repr(usize)]
Expand Down Expand Up @@ -200,23 +198,24 @@ where
self.table_packing
}

/// Generate a unified batch STARK proof for all circuit tables.
/// Generate a batch STARK proof for extension field circuits with explicit W.
///
/// The caller must provide the binomial parameter W for the extension field.
#[instrument(skip_all)]
pub fn prove_all_tables<EF>(
pub fn prove_all_tables_extension<EF>(
&self,
traces: &Traces<EF>,
w: MVal<SC>,
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
where
EF: Field + BasedVectorSpace<MVal<SC>> + ExtractBinomialW<MVal<SC>>,
EF: Field + BasedVectorSpace<MVal<SC>>,
{
let w_opt = EF::extract_w();
match EF::DIMENSION {
1 => self.prove::<EF, 1>(traces, None),
2 => self.prove::<EF, 2>(traces, w_opt),
4 => self.prove::<EF, 4>(traces, w_opt),
6 => self.prove::<EF, 6>(traces, w_opt),
8 => self.prove::<EF, 8>(traces, w_opt),
d => Err(BatchStarkProverError::UnsupportedDegree(d)),
2 => self.prove_extension::<EF, 2>(traces, w),
4 => self.prove_extension::<EF, 4>(traces, w),
6 => self.prove_extension::<EF, 6>(traces, w),
8 => self.prove_extension::<EF, 8>(traces, w),
_ => Err(BatchStarkProverError::UnsupportedDegree(EF::DIMENSION)),
}
}

Expand All @@ -235,15 +234,119 @@ where
}
}

/// Generate a batch STARK proof for a specific extension field degree.
/// Generate a batch STARK proof for base field circuits (D=1).
pub fn prove_all_tables_base<EF>(
&self,
traces: &Traces<EF>,
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
where
EF: Field + BasedVectorSpace<MVal<SC>>,
{
const D: usize = 1;
// TODO: Consider parallelizing AIR construction and trace-to-matrix conversions.
let packing = self.table_packing;
let add_lanes = packing.add_lanes();
let mul_lanes = packing.mul_lanes();

// Witness
let witness_rows = traces.witness_trace.values.len();
let witness_air = WitnessAir::<_, D>::new(witness_rows);
let witness_matrix = WitnessAir::<_, D>::trace_to_matrix(&traces.witness_trace);

// Const
let const_rows = traces.const_trace.values.len();
let const_air = ConstAir::new(const_rows);
let const_matrix = ConstAir::<_, D>::trace_to_matrix(&traces.const_trace);

// Public
let public_rows = traces.public_trace.values.len();
let public_air = PublicAir::new(public_rows);
let public_matrix = PublicAir::<_, D>::trace_to_matrix(&traces.public_trace);

// Add
let add_rows = traces.add_trace.lhs_values.len();
let add_air = AddAir::new(add_rows, add_lanes);
let add_matrix = AddAir::<_, D>::trace_to_matrix(&traces.add_trace, add_lanes);

// Mul - base field doesn't need binomial parameter
let mul_rows = traces.mul_trace.lhs_values.len();
let mul_air = MulAir::new(mul_rows, mul_lanes);
let mul_matrix = MulAir::<_, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);

// Mmcs
let mmcs_air = MmcsVerifyAir::new(self.mmcs_config);
let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
let mmcs_rows = mmcs_matrix.height();

// Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
let air_witness = CircuitTableAir::Witness(witness_air);
let air_const = CircuitTableAir::Const(const_air);
let air_public = CircuitTableAir::Public(public_air);
let air_add = CircuitTableAir::Add(add_air);
let air_mul = CircuitTableAir::Mul(mul_air);
let air_mmcs = CircuitTableAir::Mmcs(mmcs_air);

// Pre-size for performance
let mut instances = Vec::with_capacity(NUM_TABLES);
instances.extend([
StarkInstance {
air: &air_witness,
trace: witness_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_const,
trace: const_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_public,
trace: public_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_add,
trace: add_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_mul,
trace: mul_matrix,
public_values: vec![],
},
StarkInstance {
air: &air_mmcs,
trace: mmcs_matrix,
public_values: vec![],
},
]);

let proof = p3_batch_stark::prove_batch(&self.config, instances);

Ok(BatchStarkProof {
proof,
table_packing: packing,
rows: RowCounts::new([
witness_rows,
const_rows,
public_rows,
add_rows,
mul_rows,
mmcs_rows,
]),
ext_degree: D,
w_binomial: None,
mmcs_config: self.mmcs_config,
})
}

/// Generate a batch STARK proof for extension field circuits (D>1).
///
/// This is the core proving logic that handles all circuit tables for a given
/// extension field dimension. It constructs AIRs, converts traces to matrices,
/// and generates the unified proof.
fn prove<EF, const D: usize>(
/// The binomial parameter W must be provided by the caller.
fn prove_extension<EF, const D: usize>(
&self,
traces: &Traces<EF>,
w_binomial: Option<MVal<SC>>,
w: MVal<SC>,
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
where
EF: Field + BasedVectorSpace<MVal<SC>>,
Expand All @@ -256,43 +359,32 @@ where

// Witness
let witness_rows = traces.witness_trace.values.len();
let witness_air = WitnessAir::<MVal<SC>, D>::new(witness_rows);
let witness_matrix: RowMajorMatrix<MVal<SC>> =
WitnessAir::<MVal<SC>, D>::trace_to_matrix(&traces.witness_trace);
let witness_air = WitnessAir::<_, D>::new(witness_rows);
let witness_matrix = WitnessAir::<_, D>::trace_to_matrix(&traces.witness_trace);

// Const
let const_rows = traces.const_trace.values.len();
let const_air = ConstAir::<MVal<SC>, D>::new(const_rows);
let const_matrix: RowMajorMatrix<MVal<SC>> =
ConstAir::<MVal<SC>, D>::trace_to_matrix(&traces.const_trace);
let const_air = ConstAir::new(const_rows);
let const_matrix = ConstAir::<_, D>::trace_to_matrix(&traces.const_trace);

// Public
let public_rows = traces.public_trace.values.len();
let public_air = PublicAir::<MVal<SC>, D>::new(public_rows);
let public_matrix: RowMajorMatrix<MVal<SC>> =
PublicAir::<MVal<SC>, D>::trace_to_matrix(&traces.public_trace);
let public_air = PublicAir::new(public_rows);
let public_matrix = PublicAir::<_, D>::trace_to_matrix(&traces.public_trace);

// Add
let add_rows = traces.add_trace.lhs_values.len();
let add_air = AddAir::<MVal<SC>, D>::new(add_rows, add_lanes);
let add_matrix: RowMajorMatrix<MVal<SC>> =
AddAir::<MVal<SC>, D>::trace_to_matrix(&traces.add_trace, add_lanes);
let add_air = AddAir::new(add_rows, add_lanes);
let add_matrix = AddAir::<_, D>::trace_to_matrix(&traces.add_trace, add_lanes);

// Mul
// Mul - extension field uses provided W parameter
let mul_rows = traces.mul_trace.lhs_values.len();
let mul_air: MulAir<MVal<SC>, D> = if D == 1 {
MulAir::<MVal<SC>, D>::new(mul_rows, mul_lanes)
} else {
let w = w_binomial.ok_or(BatchStarkProverError::MissingWForExtension)?;
MulAir::<MVal<SC>, D>::new_binomial(mul_rows, mul_lanes, w)
};
let mul_matrix: RowMajorMatrix<MVal<SC>> =
MulAir::<MVal<SC>, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);
let mul_air = MulAir::new_binomial(mul_rows, mul_lanes, w);
let mul_matrix = MulAir::<_, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);

// Mmcs
let mmcs_air = MmcsVerifyAir::<MVal<SC>>::new(self.mmcs_config);
let mmcs_matrix: RowMajorMatrix<MVal<SC>> =
MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
let mmcs_air = MmcsVerifyAir::new(self.mmcs_config);
let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
let mmcs_rows: usize = mmcs_matrix.height();

// Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
Expand Down Expand Up @@ -352,7 +444,7 @@ where
mmcs_rows,
]),
ext_degree: D,
w_binomial: if D > 1 { w_binomial } else { None },
w_binomial: Some(w),
mmcs_config: self.mmcs_config,
})
}
Expand Down Expand Up @@ -419,7 +511,7 @@ mod tests {
use p3_baby_bear::BabyBear;
use p3_circuit::builder::CircuitBuilder;
use p3_field::PrimeCharacteristicRing;
use p3_field::extension::BinomialExtensionField;
use p3_field::extension::{BinomialExtensionField, BinomiallyExtendable};
use p3_goldilocks::Goldilocks;
use p3_koala_bear::KoalaBear;

Expand Down Expand Up @@ -456,7 +548,7 @@ mod tests {

let cfg = config::baby_bear().build();
let prover = BatchStarkProver::new(cfg);
let proof = prover.prove_all_tables(&traces).unwrap();
let proof = prover.prove_all_tables_base(&traces).unwrap();
assert_eq!(proof.ext_degree, 1);
assert!(proof.w_binomial.is_none());
prover.verify_all_tables(&proof).unwrap();
Expand Down Expand Up @@ -503,11 +595,15 @@ mod tests {

let cfg = config::baby_bear().build();
let prover = BatchStarkProver::new(cfg);
let proof = prover.prove_all_tables(&traces).unwrap();
let proof = prover
.prove_all_tables_extension(&traces, <BabyBear as BinomiallyExtendable<4>>::W)
.unwrap();
assert_eq!(proof.ext_degree, 4);
// Ensure W was captured
let expected_w = <Ext4 as ExtractBinomialW<BabyBear>>::extract_w().unwrap();
assert_eq!(proof.w_binomial, Some(expected_w));
assert_eq!(
proof.w_binomial,
Some(<BabyBear as BinomiallyExtendable<4>>::W)
);
prover.verify_all_tables(&proof).unwrap();
}

Expand Down Expand Up @@ -541,7 +637,7 @@ mod tests {

let cfg = config::koala_bear().build();
let prover = BatchStarkProver::new(cfg);
let proof = prover.prove_all_tables(&traces).unwrap();
let proof = prover.prove_all_tables_base(&traces).unwrap();
assert_eq!(proof.ext_degree, 1);
assert!(proof.w_binomial.is_none());
prover.verify_all_tables(&proof).unwrap();
Expand Down Expand Up @@ -620,10 +716,14 @@ mod tests {

let cfg = config::koala_bear().build();
let prover = BatchStarkProver::new(cfg);
let proof = prover.prove_all_tables(&traces).unwrap();
let proof = prover
.prove_all_tables_extension(&traces, <KoalaBear as BinomiallyExtendable<8>>::W)
.unwrap();
assert_eq!(proof.ext_degree, 8);
let expected_w = <KBExtField as ExtractBinomialW<KoalaBear>>::extract_w().unwrap();
assert_eq!(proof.w_binomial, Some(expected_w));
assert_eq!(
proof.w_binomial,
Some(<KoalaBear as BinomiallyExtendable<8>>::W)
);
prover.verify_all_tables(&proof).unwrap();
}

Expand Down Expand Up @@ -668,10 +768,14 @@ mod tests {

let cfg = config::goldilocks().build();
let prover = BatchStarkProver::new(cfg);
let proof = prover.prove_all_tables(&traces).unwrap();
let proof = prover
.prove_all_tables_extension(&traces, <Goldilocks as BinomiallyExtendable<2>>::W)
.unwrap();
assert_eq!(proof.ext_degree, 2);
let expected_w = <Ext2 as ExtractBinomialW<Goldilocks>>::extract_w().unwrap();
assert_eq!(proof.w_binomial, Some(expected_w));
assert_eq!(
proof.w_binomial,
Some(<Goldilocks as BinomiallyExtendable<2>>::W)
);
prover.verify_all_tables(&proof).unwrap();
}
}
43 changes: 0 additions & 43 deletions circuit-prover/src/field_params.rs

This file was deleted.

1 change: 0 additions & 1 deletion circuit-prover/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ extern crate alloc;
pub mod air;
pub mod batch_stark_prover;
pub mod config;
pub mod field_params;
pub mod prover;

// Re-export main API
Expand Down
Loading
Loading