Skip to content

Commit bfb62fe

Browse files
committed
circuit prover: simplify w handling in batch stark prover
1 parent a8760d0 commit bfb62fe

File tree

4 files changed

+205
-122
lines changed

4 files changed

+205
-122
lines changed

circuit-prover/src/batch_stark_prover.rs

Lines changed: 159 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,12 @@ use p3_batch_stark::{BatchProof, StarkGenericConfig as MSGC, StarkInstance, Val
1010
use p3_circuit::tables::Traces;
1111
use p3_field::{BasedVectorSpace, Field};
1212
use p3_matrix::Matrix;
13-
use p3_matrix::dense::RowMajorMatrix;
1413
use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir};
1514
use thiserror::Error;
1615
use tracing::instrument;
1716

1817
use crate::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir};
1918
use crate::config::StarkField;
20-
use crate::field_params::ExtractBinomialW;
2119
use crate::prover::TablePacking;
2220

2321
#[repr(usize)]
@@ -200,23 +198,24 @@ where
200198
self.table_packing
201199
}
202200

203-
/// Generate a unified batch STARK proof for all circuit tables.
201+
/// Generate a batch STARK proof for extension field circuits with explicit W.
202+
///
203+
/// The caller must provide the binomial parameter W for the extension field.
204204
#[instrument(skip_all)]
205-
pub fn prove_all_tables<EF>(
205+
pub fn prove_all_tables_extension<EF>(
206206
&self,
207207
traces: &Traces<EF>,
208+
w: MVal<SC>,
208209
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
209210
where
210-
EF: Field + BasedVectorSpace<MVal<SC>> + ExtractBinomialW<MVal<SC>>,
211+
EF: Field + BasedVectorSpace<MVal<SC>>,
211212
{
212-
let w_opt = EF::extract_w();
213213
match EF::DIMENSION {
214-
1 => self.prove::<EF, 1>(traces, None),
215-
2 => self.prove::<EF, 2>(traces, w_opt),
216-
4 => self.prove::<EF, 4>(traces, w_opt),
217-
6 => self.prove::<EF, 6>(traces, w_opt),
218-
8 => self.prove::<EF, 8>(traces, w_opt),
219-
d => Err(BatchStarkProverError::UnsupportedDegree(d)),
214+
2 => self.prove_extension::<EF, 2>(traces, w),
215+
4 => self.prove_extension::<EF, 4>(traces, w),
216+
6 => self.prove_extension::<EF, 6>(traces, w),
217+
8 => self.prove_extension::<EF, 8>(traces, w),
218+
_ => Err(BatchStarkProverError::UnsupportedDegree(EF::DIMENSION)),
220219
}
221220
}
222221

@@ -235,15 +234,119 @@ where
235234
}
236235
}
237236

238-
/// Generate a batch STARK proof for a specific extension field degree.
237+
/// Generate a batch STARK proof for base field circuits (D=1).
238+
pub fn prove_all_tables_base<EF>(
239+
&self,
240+
traces: &Traces<EF>,
241+
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
242+
where
243+
EF: Field + BasedVectorSpace<MVal<SC>>,
244+
{
245+
const D: usize = 1;
246+
// TODO: Consider parallelizing AIR construction and trace-to-matrix conversions.
247+
let packing = self.table_packing;
248+
let add_lanes = packing.add_lanes();
249+
let mul_lanes = packing.mul_lanes();
250+
251+
// Witness
252+
let witness_rows = traces.witness_trace.values.len();
253+
let witness_air = WitnessAir::<_, D>::new(witness_rows);
254+
let witness_matrix = WitnessAir::<_, D>::trace_to_matrix(&traces.witness_trace);
255+
256+
// Const
257+
let const_rows = traces.const_trace.values.len();
258+
let const_air = ConstAir::new(const_rows);
259+
let const_matrix = ConstAir::<_, D>::trace_to_matrix(&traces.const_trace);
260+
261+
// Public
262+
let public_rows = traces.public_trace.values.len();
263+
let public_air = PublicAir::new(public_rows);
264+
let public_matrix = PublicAir::<_, D>::trace_to_matrix(&traces.public_trace);
265+
266+
// Add
267+
let add_rows = traces.add_trace.lhs_values.len();
268+
let add_air = AddAir::new(add_rows, add_lanes);
269+
let add_matrix = AddAir::<_, D>::trace_to_matrix(&traces.add_trace, add_lanes);
270+
271+
// Mul - base field doesn't need binomial parameter
272+
let mul_rows = traces.mul_trace.lhs_values.len();
273+
let mul_air = MulAir::new(mul_rows, mul_lanes);
274+
let mul_matrix = MulAir::<_, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);
275+
276+
// Mmcs
277+
let mmcs_air = MmcsVerifyAir::new(self.mmcs_config);
278+
let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
279+
let mmcs_rows = mmcs_matrix.height();
280+
281+
// Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
282+
let air_witness = CircuitTableAir::Witness(witness_air);
283+
let air_const = CircuitTableAir::Const(const_air);
284+
let air_public = CircuitTableAir::Public(public_air);
285+
let air_add = CircuitTableAir::Add(add_air);
286+
let air_mul = CircuitTableAir::Mul(mul_air);
287+
let air_mmcs = CircuitTableAir::Mmcs(mmcs_air);
288+
289+
// Pre-size for performance
290+
let mut instances = Vec::with_capacity(NUM_TABLES);
291+
instances.extend([
292+
StarkInstance {
293+
air: &air_witness,
294+
trace: witness_matrix,
295+
public_values: vec![],
296+
},
297+
StarkInstance {
298+
air: &air_const,
299+
trace: const_matrix,
300+
public_values: vec![],
301+
},
302+
StarkInstance {
303+
air: &air_public,
304+
trace: public_matrix,
305+
public_values: vec![],
306+
},
307+
StarkInstance {
308+
air: &air_add,
309+
trace: add_matrix,
310+
public_values: vec![],
311+
},
312+
StarkInstance {
313+
air: &air_mul,
314+
trace: mul_matrix,
315+
public_values: vec![],
316+
},
317+
StarkInstance {
318+
air: &air_mmcs,
319+
trace: mmcs_matrix,
320+
public_values: vec![],
321+
},
322+
]);
323+
324+
let proof = p3_batch_stark::prove_batch(&self.config, instances);
325+
326+
Ok(BatchStarkProof {
327+
proof,
328+
table_packing: packing,
329+
rows: RowCounts::new([
330+
witness_rows,
331+
const_rows,
332+
public_rows,
333+
add_rows,
334+
mul_rows,
335+
mmcs_rows,
336+
]),
337+
ext_degree: D,
338+
w_binomial: None,
339+
mmcs_config: self.mmcs_config,
340+
})
341+
}
342+
343+
/// Generate a batch STARK proof for extension field circuits (D>1).
239344
///
240-
/// This is the core proving logic that handles all circuit tables for a given
241-
/// extension field dimension. It constructs AIRs, converts traces to matrices,
242-
/// and generates the unified proof.
243-
fn prove<EF, const D: usize>(
345+
/// The binomial parameter W must be provided by the caller.
346+
fn prove_extension<EF, const D: usize>(
244347
&self,
245348
traces: &Traces<EF>,
246-
w_binomial: Option<MVal<SC>>,
349+
w: MVal<SC>,
247350
) -> Result<BatchStarkProof<SC>, BatchStarkProverError>
248351
where
249352
EF: Field + BasedVectorSpace<MVal<SC>>,
@@ -256,43 +359,32 @@ where
256359

257360
// Witness
258361
let witness_rows = traces.witness_trace.values.len();
259-
let witness_air = WitnessAir::<MVal<SC>, D>::new(witness_rows);
260-
let witness_matrix: RowMajorMatrix<MVal<SC>> =
261-
WitnessAir::<MVal<SC>, D>::trace_to_matrix(&traces.witness_trace);
362+
let witness_air = WitnessAir::<_, D>::new(witness_rows);
363+
let witness_matrix = WitnessAir::<_, D>::trace_to_matrix(&traces.witness_trace);
262364

263365
// Const
264366
let const_rows = traces.const_trace.values.len();
265-
let const_air = ConstAir::<MVal<SC>, D>::new(const_rows);
266-
let const_matrix: RowMajorMatrix<MVal<SC>> =
267-
ConstAir::<MVal<SC>, D>::trace_to_matrix(&traces.const_trace);
367+
let const_air = ConstAir::new(const_rows);
368+
let const_matrix = ConstAir::<_, D>::trace_to_matrix(&traces.const_trace);
268369

269370
// Public
270371
let public_rows = traces.public_trace.values.len();
271-
let public_air = PublicAir::<MVal<SC>, D>::new(public_rows);
272-
let public_matrix: RowMajorMatrix<MVal<SC>> =
273-
PublicAir::<MVal<SC>, D>::trace_to_matrix(&traces.public_trace);
372+
let public_air = PublicAir::new(public_rows);
373+
let public_matrix = PublicAir::<_, D>::trace_to_matrix(&traces.public_trace);
274374

275375
// Add
276376
let add_rows = traces.add_trace.lhs_values.len();
277-
let add_air = AddAir::<MVal<SC>, D>::new(add_rows, add_lanes);
278-
let add_matrix: RowMajorMatrix<MVal<SC>> =
279-
AddAir::<MVal<SC>, D>::trace_to_matrix(&traces.add_trace, add_lanes);
377+
let add_air = AddAir::new(add_rows, add_lanes);
378+
let add_matrix = AddAir::<_, D>::trace_to_matrix(&traces.add_trace, add_lanes);
280379

281-
// Mul
380+
// Mul - extension field uses provided W parameter
282381
let mul_rows = traces.mul_trace.lhs_values.len();
283-
let mul_air: MulAir<MVal<SC>, D> = if D == 1 {
284-
MulAir::<MVal<SC>, D>::new(mul_rows, mul_lanes)
285-
} else {
286-
let w = w_binomial.ok_or(BatchStarkProverError::MissingWForExtension)?;
287-
MulAir::<MVal<SC>, D>::new_binomial(mul_rows, mul_lanes, w)
288-
};
289-
let mul_matrix: RowMajorMatrix<MVal<SC>> =
290-
MulAir::<MVal<SC>, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);
382+
let mul_air = MulAir::new_binomial(mul_rows, mul_lanes, w);
383+
let mul_matrix = MulAir::<_, D>::trace_to_matrix(&traces.mul_trace, mul_lanes);
291384

292385
// Mmcs
293-
let mmcs_air = MmcsVerifyAir::<MVal<SC>>::new(self.mmcs_config);
294-
let mmcs_matrix: RowMajorMatrix<MVal<SC>> =
295-
MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
386+
let mmcs_air = MmcsVerifyAir::new(self.mmcs_config);
387+
let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace);
296388
let mmcs_rows: usize = mmcs_matrix.height();
297389

298390
// Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
@@ -352,7 +444,7 @@ where
352444
mmcs_rows,
353445
]),
354446
ext_degree: D,
355-
w_binomial: if D > 1 { w_binomial } else { None },
447+
w_binomial: Some(w),
356448
mmcs_config: self.mmcs_config,
357449
})
358450
}
@@ -419,7 +511,7 @@ mod tests {
419511
use p3_baby_bear::BabyBear;
420512
use p3_circuit::builder::CircuitBuilder;
421513
use p3_field::PrimeCharacteristicRing;
422-
use p3_field::extension::BinomialExtensionField;
514+
use p3_field::extension::{BinomialExtensionField, BinomiallyExtendable};
423515
use p3_goldilocks::Goldilocks;
424516
use p3_koala_bear::KoalaBear;
425517

@@ -456,7 +548,7 @@ mod tests {
456548

457549
let cfg = config::baby_bear().build();
458550
let prover = BatchStarkProver::new(cfg);
459-
let proof = prover.prove_all_tables(&traces).unwrap();
551+
let proof = prover.prove_all_tables_base(&traces).unwrap();
460552
assert_eq!(proof.ext_degree, 1);
461553
assert!(proof.w_binomial.is_none());
462554
prover.verify_all_tables(&proof).unwrap();
@@ -503,11 +595,15 @@ mod tests {
503595

504596
let cfg = config::baby_bear().build();
505597
let prover = BatchStarkProver::new(cfg);
506-
let proof = prover.prove_all_tables(&traces).unwrap();
598+
let proof = prover
599+
.prove_all_tables_extension(&traces, <BabyBear as BinomiallyExtendable<4>>::W)
600+
.unwrap();
507601
assert_eq!(proof.ext_degree, 4);
508602
// Ensure W was captured
509-
let expected_w = <Ext4 as ExtractBinomialW<BabyBear>>::extract_w().unwrap();
510-
assert_eq!(proof.w_binomial, Some(expected_w));
603+
assert_eq!(
604+
proof.w_binomial,
605+
Some(<BabyBear as BinomiallyExtendable<4>>::W)
606+
);
511607
prover.verify_all_tables(&proof).unwrap();
512608
}
513609

@@ -541,7 +637,7 @@ mod tests {
541637

542638
let cfg = config::koala_bear().build();
543639
let prover = BatchStarkProver::new(cfg);
544-
let proof = prover.prove_all_tables(&traces).unwrap();
640+
let proof = prover.prove_all_tables_base(&traces).unwrap();
545641
assert_eq!(proof.ext_degree, 1);
546642
assert!(proof.w_binomial.is_none());
547643
prover.verify_all_tables(&proof).unwrap();
@@ -620,10 +716,14 @@ mod tests {
620716

621717
let cfg = config::koala_bear().build();
622718
let prover = BatchStarkProver::new(cfg);
623-
let proof = prover.prove_all_tables(&traces).unwrap();
719+
let proof = prover
720+
.prove_all_tables_extension(&traces, <KoalaBear as BinomiallyExtendable<8>>::W)
721+
.unwrap();
624722
assert_eq!(proof.ext_degree, 8);
625-
let expected_w = <KBExtField as ExtractBinomialW<KoalaBear>>::extract_w().unwrap();
626-
assert_eq!(proof.w_binomial, Some(expected_w));
723+
assert_eq!(
724+
proof.w_binomial,
725+
Some(<KoalaBear as BinomiallyExtendable<8>>::W)
726+
);
627727
prover.verify_all_tables(&proof).unwrap();
628728
}
629729

@@ -668,10 +768,14 @@ mod tests {
668768

669769
let cfg = config::goldilocks().build();
670770
let prover = BatchStarkProver::new(cfg);
671-
let proof = prover.prove_all_tables(&traces).unwrap();
771+
let proof = prover
772+
.prove_all_tables_extension(&traces, <Goldilocks as BinomiallyExtendable<2>>::W)
773+
.unwrap();
672774
assert_eq!(proof.ext_degree, 2);
673-
let expected_w = <Ext2 as ExtractBinomialW<Goldilocks>>::extract_w().unwrap();
674-
assert_eq!(proof.w_binomial, Some(expected_w));
775+
assert_eq!(
776+
proof.w_binomial,
777+
Some(<Goldilocks as BinomiallyExtendable<2>>::W)
778+
);
675779
prover.verify_all_tables(&proof).unwrap();
676780
}
677781
}

circuit-prover/src/field_params.rs

Lines changed: 0 additions & 43 deletions
This file was deleted.

circuit-prover/src/lib.rs

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ extern crate alloc;
4040
pub mod air;
4141
pub mod batch_stark_prover;
4242
pub mod config;
43-
pub mod field_params;
4443
pub mod prover;
4544

4645
// Re-export main API

0 commit comments

Comments
 (0)