diff --git a/circuit-prover/src/prover.rs b/circuit-prover/src/prover.rs index d9d6ea4..155fd7b 100644 --- a/circuit-prover/src/prover.rs +++ b/circuit-prover/src/prover.rs @@ -10,16 +10,17 @@ //! Supports base fields (D=1) and binomial extension fields (D>1), with automatic //! detection of the binomial parameter `W` for extension-field multiplication. +use alloc::boxed::Box; use alloc::vec; use alloc::vec::Vec; use p3_circuit::tables::Traces; use p3_circuit::{CircuitBuilderError, CircuitError}; +use p3_field::extension::BinomialExtensionField; use p3_field::{BasedVectorSpace, Field}; use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir}; use p3_uni_stark::{StarkGenericConfig, Val, prove, verify}; use thiserror::Error; -use tracing::instrument; use crate::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir}; use crate::config::StarkField; @@ -81,12 +82,16 @@ pub struct MultiTableProof where SC: StarkGenericConfig, { + // Primitive tables pub witness: TableProof, pub constants: TableProof, pub public: TableProof, pub add: TableProof, pub mul: TableProof, - pub mmcs: TableProof, + + /// Dynamic non-primitive table proofs, extensible at runtime. + pub non_primitives: Vec>, + /// Packing configuration used when generating the proofs. pub table_packing: TablePacking, /// Extension field degree: 1 for base field; otherwise the extension degree used. @@ -95,6 +100,19 @@ where pub w_binomial: Option>, } +/// Dynamic table proof entry for non-primitive tables +pub struct TableProofEntry +where + SC: StarkGenericConfig, +{ + /// Identifier for the table. + pub id: &'static str, + /// Proof for the table. + pub proof: StarkProof, + /// Number of logical rows (operations) prior to any per-row packing. + pub rows: usize, +} + /// Multi-table STARK prover for circuit execution traces. /// /// Generic over `SC: StarkGenericConfig` to support different field configurations. @@ -104,7 +122,8 @@ where { config: SC, table_packing: TablePacking, - mmcs_config: MmcsTableConfig, + /// Registered non-primitive provers. + non_primitive_provers: Vec>>, } /// Errors that can arise during proving or verification. @@ -140,7 +159,7 @@ where Self { config, table_packing: TablePacking::default(), - mmcs_config: MmcsTableConfig::default(), + non_primitive_provers: Vec::new(), } } @@ -157,17 +176,25 @@ where self.table_packing } + /// Register MMCS verification plugin pub fn with_mmcs_table(mut self, mmcs_config: MmcsTableConfig) -> Self { - self.mmcs_config = mmcs_config; + let plugin = Box::new(MmcsProver { + config: mmcs_config, + }); + self.register_prover(plugin); self } + /// Register a non-primitive prover + pub fn register_prover(&mut self, plugin: Box>) { + self.non_primitive_provers.push(plugin); + } + /// Generate proofs for all circuit tables. /// /// Automatically detects whether to use base field or binomial extension field /// proving based on the circuit element type `EF`. For extension fields, /// the binomial parameter W is automatically extracted. - #[instrument(skip_all)] pub fn prove_all_tables( &self, traces: &Traces, @@ -249,9 +276,59 @@ where }; let mul_proof = prove(&self.config, &mul_air, mul_matrix, pis); - let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace); - let mmcs_air = MmcsVerifyAir::new(self.mmcs_config); - let mmcs_proof = prove(&self.config, &mmcs_air, mmcs_matrix, pis); + // Handle all registered non-primitive tables dynamically + let mut non_primitives = Vec::new(); + match D { + 1 => { + let t: &Traces> = unsafe { transmute_traces(traces) }; + for p in &self.non_primitive_provers { + if let Some(entry) = p.prove_d1(&self.config, table_packing, t, pis) { + non_primitives.push(entry); + } + } + } + 2 => { + type EF2 = BinomialExtensionField; + + let t: &Traces>> = unsafe { transmute_traces(traces) }; + for p in &self.non_primitive_provers { + if let Some(entry) = p.prove_d2(&self.config, table_packing, t, pis) { + non_primitives.push(entry); + } + } + } + 4 => { + type EF4 = BinomialExtensionField; + + let t: &Traces>> = unsafe { transmute_traces(traces) }; + for p in &self.non_primitive_provers { + if let Some(entry) = p.prove_d4(&self.config, table_packing, t, pis) { + non_primitives.push(entry); + } + } + } + 6 => { + type EF6 = BinomialExtensionField; + + let t: &Traces>> = unsafe { transmute_traces(traces) }; + for p in &self.non_primitive_provers { + if let Some(entry) = p.prove_d6(&self.config, table_packing, t, pis) { + non_primitives.push(entry); + } + } + } + 8 => { + type EF8 = BinomialExtensionField; + + let t: &Traces>> = unsafe { transmute_traces(traces) }; + for p in &self.non_primitive_provers { + if let Some(entry) = p.prove_d8(&self.config, table_packing, t, pis) { + non_primitives.push(entry); + } + } + } + _ => unreachable!(), + } Ok(MultiTableProof { witness: TableProof { @@ -274,15 +351,7 @@ where proof: mul_proof, rows: traces.mul_trace.lhs_values.len(), }, - mmcs: TableProof { - proof: mmcs_proof, - rows: traces - .mmcs_trace - .mmcs_paths - .iter() - .map(|path| path.left_values.len() + 1) - .sum(), - }, + non_primitives, table_packing, ext_degree: D, w_binomial: if D > 1 { w_binomial } else { None }, @@ -329,19 +398,264 @@ where verify(&self.config, &mul_air, &proof.mul.proof, pis) .map_err(|_| ProverError::VerificationFailed { phase: "mul" })?; - // MmcsVerify - - let mmcs_air = MmcsVerifyAir::new(self.mmcs_config); - verify(&self.config, &mmcs_air, &proof.mmcs.proof, pis).map_err(|_| { - ProverError::VerificationFailed { - phase: "mmcs_verify", - } - })?; + // Verify non-primitive tables + for entry in &proof.non_primitives { + let plugin = self + .non_primitive_provers + .iter() + .find(|p| p.id() == entry.id) + .ok_or(ProverError::VerificationFailed { + phase: "unknown_non_primitive", + })?; + plugin.verify(&self.config, D, proof.table_packing, entry, w_binomial, pis)?; + } Ok(()) } } +#[allow(clippy::ptr_arg)] // Plonky3 treats `pis` as `&Vec<_>`. +/// Table prover plugin trait, used to prove and verify non-primitive tables. +/// +/// Because of some limitations of object-safety, we need to split all instances +/// of the generic `prove` method into extension-degree-specific `prove_dN` methods. +/// +/// Users wishing to implement a new non-primitive table prover can implement this trait +/// and use the `impl_table_prover_degrees_from_base!` macro to automatically derive +/// the extension-degree-specific `prove_dN` methods. +pub trait TableProver: Send + Sync +where + SC: StarkGenericConfig, +{ + /// Identifier for this prover. + fn id(&self) -> &'static str; + + /// Prove a non-primitive table in the base field. + fn prove_d1( + &self, + cfg: &SC, + packing: TablePacking, + traces: &Traces>, + pis: &Vec>, + ) -> Option>; + + /// Prove a non-primitive table in the extension field of degree 2. + fn prove_d2( + &self, + cfg: &SC, + packing: TablePacking, + traces: &Traces, 2>>, + pis: &Vec>, + ) -> Option>; + + /// Prove a non-primitive table in the extension field of degree 4. + fn prove_d4( + &self, + cfg: &SC, + packing: TablePacking, + traces: &Traces, 4>>, + pis: &Vec>, + ) -> Option>; + + /// Prove a non-primitive table in the extension field of degree 6. + fn prove_d6( + &self, + cfg: &SC, + packing: TablePacking, + traces: &Traces, 6>>, + pis: &Vec>, + ) -> Option>; + + /// Prove a non-primitive table in the extension field of degree 8. + fn prove_d8( + &self, + cfg: &SC, + packing: TablePacking, + traces: &Traces, 8>>, + pis: &Vec>, + ) -> Option>; + + /// Verify a non-primitive table. + fn verify( + &self, + cfg: &SC, + degree: usize, + packing: TablePacking, + entry: &TableProofEntry, + w_binomial: Option>, + pis: &Vec>, + ) -> Result<(), ProverError>; +} + +#[inline(always)] +unsafe fn transmute_traces(t: &Traces) -> &Traces { + unsafe { &*(t as *const _ as *const Traces) } +} + +#[macro_export] +/// Macro to implement the `TableProver` trait for a given base prover. +/// +/// It will derive all the `prove_dN` methods for proving in extension fields of degree N. +/// +/// # Examples +/// +/// ```ignore +/// pub struct MyProver { pub config: MyConfig } +/// +/// impl MyProver { +/// fn prove_base(&self, cfg: &SC, packing: TablePacking, traces: &Traces>, pis: &Vec>) -> Option> { +/// Some(TableProofEntry { id: "my_prover", proof: prove(cfg, &air, matrix, pis), rows: traces.values.len() }) +/// } +/// } +/// +/// impl TableProver for MyProver { +/// fn id(&self) -> &'static str { "my_prover" } +/// +/// // Derive all extension-degree-specific prove methods from the base prove method. +/// impl_table_prover_degrees_from_base!(MyProver, prove_base); +/// +/// fn verify( +/// &self, +/// cfg: &SC, +/// degree: usize, +/// packing: TablePacking, +/// entry: &TableProofEntry, +/// w_binomial: Option>, +/// pis: &Vec>, +/// ) -> Result<(), ProverError> { Ok(()) } +/// } +/// +/// ``` +macro_rules! impl_table_prover_degrees_from_base { + ($base:ident) => { + fn prove_d1( + &self, + cfg: &SC, + packing: $crate::prover::TablePacking, + traces: &p3_circuit::tables::Traces>, + pis: &alloc::vec::Vec>, + ) -> Option<$crate::prover::TableProofEntry> { + self.$base::(cfg, packing, traces, pis) + } + + fn prove_d2( + &self, + cfg: &SC, + packing: $crate::prover::TablePacking, + traces: &p3_circuit::tables::Traces< + p3_field::extension::BinomialExtensionField, 2>, + >, + pis: &alloc::vec::Vec>, + ) -> Option<$crate::prover::TableProofEntry> { + let t: &p3_circuit::tables::Traces> = + unsafe { $crate::prover::transmute_traces(traces) }; + self.$base::(cfg, packing, t, pis) + } + + fn prove_d4( + &self, + cfg: &SC, + packing: $crate::prover::TablePacking, + traces: &p3_circuit::tables::Traces< + p3_field::extension::BinomialExtensionField, 4>, + >, + pis: &alloc::vec::Vec>, + ) -> Option<$crate::prover::TableProofEntry> { + let t: &p3_circuit::tables::Traces> = + unsafe { $crate::prover::transmute_traces(traces) }; + self.$base::(cfg, packing, t, pis) + } + + fn prove_d6( + &self, + cfg: &SC, + packing: $crate::prover::TablePacking, + traces: &p3_circuit::tables::Traces< + p3_field::extension::BinomialExtensionField, 6>, + >, + pis: &alloc::vec::Vec>, + ) -> Option<$crate::prover::TableProofEntry> { + let t: &p3_circuit::tables::Traces> = + unsafe { $crate::prover::transmute_traces(traces) }; + self.$base::(cfg, packing, t, pis) + } + + fn prove_d8( + &self, + cfg: &SC, + packing: $crate::prover::TablePacking, + traces: &p3_circuit::tables::Traces< + p3_field::extension::BinomialExtensionField, 8>, + >, + pis: &alloc::vec::Vec>, + ) -> Option<$crate::prover::TableProofEntry> { + let t: &p3_circuit::tables::Traces> = + unsafe { $crate::prover::transmute_traces(traces) }; + self.$base::(cfg, packing, t, pis) + } + }; +} + +/// MMCS prover plugin +pub struct MmcsProver { + pub config: MmcsTableConfig, +} + +impl MmcsProver { + fn prove_base( + &self, + cfg: &SC, + _packing: TablePacking, + traces: &Traces>, + pis: &Vec>, + ) -> Option> + where + SC: StarkGenericConfig, + Val: StarkField, + { + let t = &traces.mmcs_trace; + if t.mmcs_paths.is_empty() { + return None; + } + let rows: usize = t.mmcs_paths.iter().map(|p| p.left_values.len() + 1).sum(); + let matrix = MmcsVerifyAir::trace_to_matrix(&self.config, t); + let air = MmcsVerifyAir::new(self.config); + let proof = prove(cfg, &air, matrix, pis); + Some(TableProofEntry { + id: >::id(self), + proof, + rows, + }) + } +} + +impl TableProver for MmcsProver +where + SC: StarkGenericConfig, + Val: StarkField, +{ + fn id(&self) -> &'static str { + "mmcs_verify" + } + + impl_table_prover_degrees_from_base!(prove_base); + + fn verify( + &self, + cfg: &SC, + _degree: usize, + _packing: TablePacking, + entry: &TableProofEntry, + _w_binomial: Option>, + pis: &Vec>, + ) -> Result<(), ProverError> { + let air = MmcsVerifyAir::new(self.config); + verify(cfg, &air, &entry.proof, pis).map_err(|_| ProverError::VerificationFailed { + phase: >::id(self), + }) + } +} + #[cfg(test)] mod tests { use p3_baby_bear::BabyBear; diff --git a/circuit/src/builder/circuit_builder.rs b/circuit/src/builder/circuit_builder.rs index 3f557ef..545dea8 100644 --- a/circuit/src/builder/circuit_builder.rs +++ b/circuit/src/builder/circuit_builder.rs @@ -67,6 +67,22 @@ where self.config.enable_mmcs(mmcs_config); } + /// Enables HashAbsorb operations. + pub fn enable_hash_absorb(&mut self, reset: bool) { + self.config.enable_hash_absorb(reset); + } + + /// Enables HashSqueeze operations. + pub fn enable_hash_squeeze(&mut self) { + self.config.enable_hash_squeeze(); + } + + /// Enables hash operations. + pub fn enable_hash(&mut self, reset: bool) { + self.enable_hash_absorb(reset); + self.enable_hash_squeeze(); + } + /// Enables FRI verification operations. pub fn enable_fri(&mut self) { self.config.enable_fri(); diff --git a/circuit/src/builder/config.rs b/circuit/src/builder/config.rs index 48865cc..b3d68b3 100644 --- a/circuit/src/builder/config.rs +++ b/circuit/src/builder/config.rs @@ -31,6 +31,19 @@ impl BuilderConfig { ); } + /// Enables HashAbsorb operations. + pub fn enable_hash_absorb(&mut self, reset: bool) { + self.enable_op( + NonPrimitiveOpType::HashAbsorb { reset }, + NonPrimitiveOpConfig::None, + ); + } + + /// Enables HashSqueeze operations. + pub fn enable_hash_squeeze(&mut self) { + self.enable_op(NonPrimitiveOpType::HashSqueeze, NonPrimitiveOpConfig::None); + } + /// Enables FRI verification operations. pub fn enable_fri(&mut self) { // TODO: Add FRI ops when available. diff --git a/circuit/src/tables/mmcs.rs b/circuit/src/tables/mmcs.rs index a152498..fb7c45a 100644 --- a/circuit/src/tables/mmcs.rs +++ b/circuit/src/tables/mmcs.rs @@ -685,7 +685,7 @@ mod tests { ]; let root_exprs = (0..config.ext_field_digest_elems) .map(|_| builder.add_public_input()) - .collect::>(); + .collect::>(); let mmcs_op_id = builder .add_mmcs_verify(&leaves_expr, &directions_expr, &root_exprs) .unwrap(); diff --git a/circuit/src/tables/runner.rs b/circuit/src/tables/runner.rs index 939b608..7374da5 100644 --- a/circuit/src/tables/runner.rs +++ b/circuit/src/tables/runner.rs @@ -261,15 +261,14 @@ impl CircuitRunner { #[cfg(test)] mod tests { extern crate std; - use alloc::vec; use std::println; use p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; + use super::*; use crate::builder::CircuitBuilder; - use crate::types::WitnessId; #[test] fn test_table_generation_basic() {