@@ -10,14 +10,12 @@ use p3_batch_stark::{BatchProof, StarkGenericConfig as MSGC, StarkInstance, Val
1010use p3_circuit:: tables:: Traces ;
1111use p3_field:: { BasedVectorSpace , Field } ;
1212use p3_matrix:: Matrix ;
13- use p3_matrix:: dense:: RowMajorMatrix ;
1413use p3_mmcs_air:: air:: { MmcsTableConfig , MmcsVerifyAir } ;
1514use thiserror:: Error ;
1615use tracing:: instrument;
1716
1817use crate :: air:: { AddAir , ConstAir , MulAir , PublicAir , WitnessAir } ;
1918use crate :: config:: StarkField ;
20- use crate :: field_params:: ExtractBinomialW ;
2119use crate :: prover:: TablePacking ;
2220
2321#[ repr( usize ) ]
@@ -200,23 +198,24 @@ where
200198 self . table_packing
201199 }
202200
203- /// Generate a unified batch STARK proof for all circuit tables.
201+ /// Generate a batch STARK proof for extension field circuits with explicit W.
202+ ///
203+ /// The caller must provide the binomial parameter W for the extension field.
204204 #[ instrument( skip_all) ]
205- pub fn prove_all_tables < EF > (
205+ pub fn prove_all_tables_extension < EF > (
206206 & self ,
207207 traces : & Traces < EF > ,
208+ w : MVal < SC > ,
208209 ) -> Result < BatchStarkProof < SC > , BatchStarkProverError >
209210 where
210- EF : Field + BasedVectorSpace < MVal < SC > > + ExtractBinomialW < MVal < SC > > ,
211+ EF : Field + BasedVectorSpace < MVal < SC > > ,
211212 {
212- let w_opt = EF :: extract_w ( ) ;
213213 match EF :: DIMENSION {
214- 1 => self . prove :: < EF , 1 > ( traces, None ) ,
215- 2 => self . prove :: < EF , 2 > ( traces, w_opt) ,
216- 4 => self . prove :: < EF , 4 > ( traces, w_opt) ,
217- 6 => self . prove :: < EF , 6 > ( traces, w_opt) ,
218- 8 => self . prove :: < EF , 8 > ( traces, w_opt) ,
219- d => Err ( BatchStarkProverError :: UnsupportedDegree ( d) ) ,
214+ 2 => self . prove_extension :: < EF , 2 > ( traces, w) ,
215+ 4 => self . prove_extension :: < EF , 4 > ( traces, w) ,
216+ 6 => self . prove_extension :: < EF , 6 > ( traces, w) ,
217+ 8 => self . prove_extension :: < EF , 8 > ( traces, w) ,
218+ _ => Err ( BatchStarkProverError :: UnsupportedDegree ( EF :: DIMENSION ) ) ,
220219 }
221220 }
222221
@@ -235,15 +234,119 @@ where
235234 }
236235 }
237236
238- /// Generate a batch STARK proof for a specific extension field degree.
237+ /// Generate a batch STARK proof for base field circuits (D=1).
238+ pub fn prove_all_tables_base < EF > (
239+ & self ,
240+ traces : & Traces < EF > ,
241+ ) -> Result < BatchStarkProof < SC > , BatchStarkProverError >
242+ where
243+ EF : Field + BasedVectorSpace < MVal < SC > > ,
244+ {
245+ const D : usize = 1 ;
246+ // TODO: Consider parallelizing AIR construction and trace-to-matrix conversions.
247+ let packing = self . table_packing ;
248+ let add_lanes = packing. add_lanes ( ) ;
249+ let mul_lanes = packing. mul_lanes ( ) ;
250+
251+ // Witness
252+ let witness_rows = traces. witness_trace . values . len ( ) ;
253+ let witness_air = WitnessAir :: < _ , D > :: new ( witness_rows) ;
254+ let witness_matrix = WitnessAir :: < _ , D > :: trace_to_matrix ( & traces. witness_trace ) ;
255+
256+ // Const
257+ let const_rows = traces. const_trace . values . len ( ) ;
258+ let const_air = ConstAir :: new ( const_rows) ;
259+ let const_matrix = ConstAir :: < _ , D > :: trace_to_matrix ( & traces. const_trace ) ;
260+
261+ // Public
262+ let public_rows = traces. public_trace . values . len ( ) ;
263+ let public_air = PublicAir :: new ( public_rows) ;
264+ let public_matrix = PublicAir :: < _ , D > :: trace_to_matrix ( & traces. public_trace ) ;
265+
266+ // Add
267+ let add_rows = traces. add_trace . lhs_values . len ( ) ;
268+ let add_air = AddAir :: new ( add_rows, add_lanes) ;
269+ let add_matrix = AddAir :: < _ , D > :: trace_to_matrix ( & traces. add_trace , add_lanes) ;
270+
271+ // Mul - base field doesn't need binomial parameter
272+ let mul_rows = traces. mul_trace . lhs_values . len ( ) ;
273+ let mul_air = MulAir :: new ( mul_rows, mul_lanes) ;
274+ let mul_matrix = MulAir :: < _ , D > :: trace_to_matrix ( & traces. mul_trace , mul_lanes) ;
275+
276+ // Mmcs
277+ let mmcs_air = MmcsVerifyAir :: new ( self . mmcs_config ) ;
278+ let mmcs_matrix = MmcsVerifyAir :: trace_to_matrix ( & self . mmcs_config , & traces. mmcs_trace ) ;
279+ let mmcs_rows = mmcs_matrix. height ( ) ;
280+
281+ // Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
282+ let air_witness = CircuitTableAir :: Witness ( witness_air) ;
283+ let air_const = CircuitTableAir :: Const ( const_air) ;
284+ let air_public = CircuitTableAir :: Public ( public_air) ;
285+ let air_add = CircuitTableAir :: Add ( add_air) ;
286+ let air_mul = CircuitTableAir :: Mul ( mul_air) ;
287+ let air_mmcs = CircuitTableAir :: Mmcs ( mmcs_air) ;
288+
289+ // Pre-size for performance
290+ let mut instances = Vec :: with_capacity ( NUM_TABLES ) ;
291+ instances. extend ( [
292+ StarkInstance {
293+ air : & air_witness,
294+ trace : witness_matrix,
295+ public_values : vec ! [ ] ,
296+ } ,
297+ StarkInstance {
298+ air : & air_const,
299+ trace : const_matrix,
300+ public_values : vec ! [ ] ,
301+ } ,
302+ StarkInstance {
303+ air : & air_public,
304+ trace : public_matrix,
305+ public_values : vec ! [ ] ,
306+ } ,
307+ StarkInstance {
308+ air : & air_add,
309+ trace : add_matrix,
310+ public_values : vec ! [ ] ,
311+ } ,
312+ StarkInstance {
313+ air : & air_mul,
314+ trace : mul_matrix,
315+ public_values : vec ! [ ] ,
316+ } ,
317+ StarkInstance {
318+ air : & air_mmcs,
319+ trace : mmcs_matrix,
320+ public_values : vec ! [ ] ,
321+ } ,
322+ ] ) ;
323+
324+ let proof = p3_batch_stark:: prove_batch ( & self . config , instances) ;
325+
326+ Ok ( BatchStarkProof {
327+ proof,
328+ table_packing : packing,
329+ rows : RowCounts :: new ( [
330+ witness_rows,
331+ const_rows,
332+ public_rows,
333+ add_rows,
334+ mul_rows,
335+ mmcs_rows,
336+ ] ) ,
337+ ext_degree : D ,
338+ w_binomial : None ,
339+ mmcs_config : self . mmcs_config ,
340+ } )
341+ }
342+
343+ /// Generate a batch STARK proof for extension field circuits (D>1).
239344 ///
240- /// This is the core proving logic that handles all circuit tables for a given
241- /// extension field dimension. It constructs AIRs, converts traces to matrices,
242- /// and generates the unified proof.
243- fn prove < EF , const D : usize > (
345+ /// The binomial parameter W must be provided by the caller.
346+ fn prove_extension < EF , const D : usize > (
244347 & self ,
245348 traces : & Traces < EF > ,
246- w_binomial : Option < MVal < SC > > ,
349+ w : MVal < SC > ,
247350 ) -> Result < BatchStarkProof < SC > , BatchStarkProverError >
248351 where
249352 EF : Field + BasedVectorSpace < MVal < SC > > ,
@@ -256,43 +359,32 @@ where
256359
257360 // Witness
258361 let witness_rows = traces. witness_trace . values . len ( ) ;
259- let witness_air = WitnessAir :: < MVal < SC > , D > :: new ( witness_rows) ;
260- let witness_matrix: RowMajorMatrix < MVal < SC > > =
261- WitnessAir :: < MVal < SC > , D > :: trace_to_matrix ( & traces. witness_trace ) ;
362+ let witness_air = WitnessAir :: < _ , D > :: new ( witness_rows) ;
363+ let witness_matrix = WitnessAir :: < _ , D > :: trace_to_matrix ( & traces. witness_trace ) ;
262364
263365 // Const
264366 let const_rows = traces. const_trace . values . len ( ) ;
265- let const_air = ConstAir :: < MVal < SC > , D > :: new ( const_rows) ;
266- let const_matrix: RowMajorMatrix < MVal < SC > > =
267- ConstAir :: < MVal < SC > , D > :: trace_to_matrix ( & traces. const_trace ) ;
367+ let const_air = ConstAir :: new ( const_rows) ;
368+ let const_matrix = ConstAir :: < _ , D > :: trace_to_matrix ( & traces. const_trace ) ;
268369
269370 // Public
270371 let public_rows = traces. public_trace . values . len ( ) ;
271- let public_air = PublicAir :: < MVal < SC > , D > :: new ( public_rows) ;
272- let public_matrix: RowMajorMatrix < MVal < SC > > =
273- PublicAir :: < MVal < SC > , D > :: trace_to_matrix ( & traces. public_trace ) ;
372+ let public_air = PublicAir :: new ( public_rows) ;
373+ let public_matrix = PublicAir :: < _ , D > :: trace_to_matrix ( & traces. public_trace ) ;
274374
275375 // Add
276376 let add_rows = traces. add_trace . lhs_values . len ( ) ;
277- let add_air = AddAir :: < MVal < SC > , D > :: new ( add_rows, add_lanes) ;
278- let add_matrix: RowMajorMatrix < MVal < SC > > =
279- AddAir :: < MVal < SC > , D > :: trace_to_matrix ( & traces. add_trace , add_lanes) ;
377+ let add_air = AddAir :: new ( add_rows, add_lanes) ;
378+ let add_matrix = AddAir :: < _ , D > :: trace_to_matrix ( & traces. add_trace , add_lanes) ;
280379
281- // Mul
380+ // Mul - extension field uses provided W parameter
282381 let mul_rows = traces. mul_trace . lhs_values . len ( ) ;
283- let mul_air: MulAir < MVal < SC > , D > = if D == 1 {
284- MulAir :: < MVal < SC > , D > :: new ( mul_rows, mul_lanes)
285- } else {
286- let w = w_binomial. ok_or ( BatchStarkProverError :: MissingWForExtension ) ?;
287- MulAir :: < MVal < SC > , D > :: new_binomial ( mul_rows, mul_lanes, w)
288- } ;
289- let mul_matrix: RowMajorMatrix < MVal < SC > > =
290- MulAir :: < MVal < SC > , D > :: trace_to_matrix ( & traces. mul_trace , mul_lanes) ;
382+ let mul_air = MulAir :: new_binomial ( mul_rows, mul_lanes, w) ;
383+ let mul_matrix = MulAir :: < _ , D > :: trace_to_matrix ( & traces. mul_trace , mul_lanes) ;
291384
292385 // Mmcs
293- let mmcs_air = MmcsVerifyAir :: < MVal < SC > > :: new ( self . mmcs_config ) ;
294- let mmcs_matrix: RowMajorMatrix < MVal < SC > > =
295- MmcsVerifyAir :: trace_to_matrix ( & self . mmcs_config , & traces. mmcs_trace ) ;
386+ let mmcs_air = MmcsVerifyAir :: new ( self . mmcs_config ) ;
387+ let mmcs_matrix = MmcsVerifyAir :: trace_to_matrix ( & self . mmcs_config , & traces. mmcs_trace ) ;
296388 let mmcs_rows: usize = mmcs_matrix. height ( ) ;
297389
298390 // Wrap AIRs in enum for heterogeneous batching and build instances in fixed order.
@@ -352,7 +444,7 @@ where
352444 mmcs_rows,
353445 ] ) ,
354446 ext_degree : D ,
355- w_binomial : if D > 1 { w_binomial } else { None } ,
447+ w_binomial : Some ( w ) ,
356448 mmcs_config : self . mmcs_config ,
357449 } )
358450 }
@@ -419,7 +511,7 @@ mod tests {
419511 use p3_baby_bear:: BabyBear ;
420512 use p3_circuit:: builder:: CircuitBuilder ;
421513 use p3_field:: PrimeCharacteristicRing ;
422- use p3_field:: extension:: BinomialExtensionField ;
514+ use p3_field:: extension:: { BinomialExtensionField , BinomiallyExtendable } ;
423515 use p3_goldilocks:: Goldilocks ;
424516 use p3_koala_bear:: KoalaBear ;
425517
@@ -456,7 +548,7 @@ mod tests {
456548
457549 let cfg = config:: baby_bear ( ) . build ( ) ;
458550 let prover = BatchStarkProver :: new ( cfg) ;
459- let proof = prover. prove_all_tables ( & traces) . unwrap ( ) ;
551+ let proof = prover. prove_all_tables_base ( & traces) . unwrap ( ) ;
460552 assert_eq ! ( proof. ext_degree, 1 ) ;
461553 assert ! ( proof. w_binomial. is_none( ) ) ;
462554 prover. verify_all_tables ( & proof) . unwrap ( ) ;
@@ -503,11 +595,15 @@ mod tests {
503595
504596 let cfg = config:: baby_bear ( ) . build ( ) ;
505597 let prover = BatchStarkProver :: new ( cfg) ;
506- let proof = prover. prove_all_tables ( & traces) . unwrap ( ) ;
598+ let proof = prover
599+ . prove_all_tables_extension ( & traces, <BabyBear as BinomiallyExtendable < 4 > >:: W )
600+ . unwrap ( ) ;
507601 assert_eq ! ( proof. ext_degree, 4 ) ;
508602 // Ensure W was captured
509- let expected_w = <Ext4 as ExtractBinomialW < BabyBear > >:: extract_w ( ) . unwrap ( ) ;
510- assert_eq ! ( proof. w_binomial, Some ( expected_w) ) ;
603+ assert_eq ! (
604+ proof. w_binomial,
605+ Some ( <BabyBear as BinomiallyExtendable <4 >>:: W )
606+ ) ;
511607 prover. verify_all_tables ( & proof) . unwrap ( ) ;
512608 }
513609
@@ -541,7 +637,7 @@ mod tests {
541637
542638 let cfg = config:: koala_bear ( ) . build ( ) ;
543639 let prover = BatchStarkProver :: new ( cfg) ;
544- let proof = prover. prove_all_tables ( & traces) . unwrap ( ) ;
640+ let proof = prover. prove_all_tables_base ( & traces) . unwrap ( ) ;
545641 assert_eq ! ( proof. ext_degree, 1 ) ;
546642 assert ! ( proof. w_binomial. is_none( ) ) ;
547643 prover. verify_all_tables ( & proof) . unwrap ( ) ;
@@ -620,10 +716,14 @@ mod tests {
620716
621717 let cfg = config:: koala_bear ( ) . build ( ) ;
622718 let prover = BatchStarkProver :: new ( cfg) ;
623- let proof = prover. prove_all_tables ( & traces) . unwrap ( ) ;
719+ let proof = prover
720+ . prove_all_tables_extension ( & traces, <KoalaBear as BinomiallyExtendable < 8 > >:: W )
721+ . unwrap ( ) ;
624722 assert_eq ! ( proof. ext_degree, 8 ) ;
625- let expected_w = <KBExtField as ExtractBinomialW < KoalaBear > >:: extract_w ( ) . unwrap ( ) ;
626- assert_eq ! ( proof. w_binomial, Some ( expected_w) ) ;
723+ assert_eq ! (
724+ proof. w_binomial,
725+ Some ( <KoalaBear as BinomiallyExtendable <8 >>:: W )
726+ ) ;
627727 prover. verify_all_tables ( & proof) . unwrap ( ) ;
628728 }
629729
@@ -668,10 +768,14 @@ mod tests {
668768
669769 let cfg = config:: goldilocks ( ) . build ( ) ;
670770 let prover = BatchStarkProver :: new ( cfg) ;
671- let proof = prover. prove_all_tables ( & traces) . unwrap ( ) ;
771+ let proof = prover
772+ . prove_all_tables_extension ( & traces, <Goldilocks as BinomiallyExtendable < 2 > >:: W )
773+ . unwrap ( ) ;
672774 assert_eq ! ( proof. ext_degree, 2 ) ;
673- let expected_w = <Ext2 as ExtractBinomialW < Goldilocks > >:: extract_w ( ) . unwrap ( ) ;
674- assert_eq ! ( proof. w_binomial, Some ( expected_w) ) ;
775+ assert_eq ! (
776+ proof. w_binomial,
777+ Some ( <Goldilocks as BinomiallyExtendable <2 >>:: W )
778+ ) ;
675779 prover. verify_all_tables ( & proof) . unwrap ( ) ;
676780 }
677781}
0 commit comments