From 7edc5f10334969a20076eff291c52c0703a1e318 Mon Sep 17 00:00:00 2001 From: Sai Deng Date: Thu, 6 Nov 2025 09:53:00 -0800 Subject: [PATCH] rm --- book/src/circuit_building.md | 4 +- circuit-prover/Cargo.toml | 4 - circuit-prover/examples/fibonacci.rs | 13 +- circuit-prover/examples/mmcs_verify.rs | 152 ------ circuit-prover/src/batch_stark_prover.rs | 35 +- circuit-prover/src/config.rs | 6 +- circuit-prover/src/lib.rs | 10 +- circuit-prover/src/prover.rs | 663 ----------------------- 8 files changed, 50 insertions(+), 837 deletions(-) delete mode 100644 circuit-prover/examples/mmcs_verify.rs delete mode 100644 circuit-prover/src/prover.rs diff --git a/book/src/circuit_building.md b/book/src/circuit_building.md index 21254aa..9c7ce59 100644 --- a/book/src/circuit_building.md +++ b/book/src/circuit_building.md @@ -123,13 +123,13 @@ runner.set_public_inputs(&[expected_fib])?; // Instantiate prover instance let config = build_standard_config_koalabear(); -let multi_prover = MultiTableProver::new(config); +let prover = BatchStarkProver::new(config); // Generate traces let traces = runner.run()?; // Prove the program -let proof = multi_prover.prove_all_tables(&traces)?; +let proof = prover.prove_all_tables(&traces)?; ``` ## Key takeaways diff --git a/circuit-prover/Cargo.toml b/circuit-prover/Cargo.toml index cc1c22c..ec4279a 100644 --- a/circuit-prover/Cargo.toml +++ b/circuit-prover/Cargo.toml @@ -46,7 +46,3 @@ parallel = ["p3-maybe-rayon/parallel"] [[example]] name = "fibonacci" path = "examples/fibonacci.rs" - -[[example]] -name = "mmcs_verify" -path = "examples/mmcs_verify.rs" diff --git a/circuit-prover/examples/fibonacci.rs b/circuit-prover/examples/fibonacci.rs index 330bb73..56385fe 100644 --- a/circuit-prover/examples/fibonacci.rs +++ b/circuit-prover/examples/fibonacci.rs @@ -1,11 +1,11 @@ use std::env; +use std::error::Error; /// Fibonacci circuit: Compute F(n) and prove correctness /// Public input: expected_result (F(n)) use p3_baby_bear::BabyBear; use p3_circuit::CircuitBuilder; -use p3_circuit_prover::prover::ProverError; -use p3_circuit_prover::{MultiTableProver, TablePacking, config}; +use p3_circuit_prover::{BatchStarkProver, TablePacking, config}; use p3_field::PrimeCharacteristicRing; use tracing_forest::ForestLayer; use tracing_forest::util::LevelFilter; @@ -26,7 +26,7 @@ fn init_logger() { .init(); } -fn main() -> Result<(), ProverError> { +fn main() -> Result<(), Box> { init_logger(); let n = env::args() @@ -64,9 +64,10 @@ fn main() -> Result<(), ProverError> { let traces = runner.run()?; let config = config::baby_bear().build(); let table_packing = TablePacking::from_counts(4, 1); - let multi_prover = MultiTableProver::new(config).with_table_packing(table_packing); - let proof = multi_prover.prove_all_tables(&traces)?; - multi_prover.verify_all_tables(&proof) + let prover = BatchStarkProver::new(config).with_table_packing(table_packing); + let proof = prover.prove_all_tables(&traces)?; + prover.verify_all_tables(&proof)?; + Ok(()) } fn compute_fibonacci_classical(n: usize) -> F { diff --git a/circuit-prover/examples/mmcs_verify.rs b/circuit-prover/examples/mmcs_verify.rs deleted file mode 100644 index 05f0a21..0000000 --- a/circuit-prover/examples/mmcs_verify.rs +++ /dev/null @@ -1,152 +0,0 @@ -use std::env; - -/// Mmcs verification circuit: Prove knowledge of a leaf in a Mmcs tree -/// Public inputs: leaf_hash, leaf_index, expected_root -/// Private inputs: mmcs path (siblings + directions) -use p3_baby_bear::BabyBear; -use p3_circuit::ops::MmcsVerifyConfig; -use p3_circuit::tables::MmcsPrivateData; -use p3_circuit::{CircuitBuilder, ExprId, MmcsOps, NonPrimitiveOpPrivateData}; -use p3_circuit_prover::prover::ProverError; -use p3_circuit_prover::{MultiTableProver, config}; -use p3_field::PrimeCharacteristicRing; -use p3_field::extension::BinomialExtensionField; -use tracing_forest::ForestLayer; -use tracing_forest::util::LevelFilter; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; -use tracing_subscriber::{EnvFilter, Registry}; - -type F = BinomialExtensionField; - -fn init_logger() { - let env_filter = EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .from_env_lossy(); - - Registry::default() - .with(env_filter) - .with(ForestLayer::default()) - .init(); -} - -fn main() -> Result<(), ProverError> { - init_logger(); - - let depth = env::args().nth(1).and_then(|s| s.parse().ok()).unwrap_or(3); - let config = config::baby_bear().build(); - let compress = config::baby_bear_compression(); - let mmcs_config = MmcsVerifyConfig::babybear_quartic_extension_default(); - - let mut builder = CircuitBuilder::::new(); - builder.enable_mmcs(&mmcs_config); - - // Public inputs: leaf hash and expected root hash - // The leaves will contain `mmcs_config.ext_field_digest_elems` wires, - // when the leaf index is odd, and an empty vector otherwise. This means - // we're proving the opening of an Mmcs to matrices of height 2^depth, 2^(depth -1), ... - let leaves: Vec> = (0..depth) - .map(|i| { - (0..if i % 2 == 0 && i != depth - 1 { - mmcs_config.ext_field_digest_elems - } else { - 0 - }) - .map(|_| builder.alloc_public_input("leaf_hash")) - .collect::>() - }) - .collect(); - let directions: Vec = (0..depth) - .map(|_| builder.alloc_public_input("directions")) - .collect(); - let expected_root = (0..mmcs_config.ext_field_digest_elems) - .map(|_| builder.alloc_public_input("expected_root")) - .collect::>(); - // Add a Mmcs verification operation - // This declares that leaf_hash and expected_root are connected to witness bus - // The AIR constraints will verify the Mmcs path is valid - let mmcs_op_id = builder.add_mmcs_verify(&leaves, &directions, &expected_root)?; - - builder.dump_allocation_log(); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - // Set public inputs - // - let leaves_value: Vec> = (0..depth) - .map(|i| { - if i % 2 == 0 && i != depth - 1 { - vec![ - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::from_u64(42), - ] - } else { - vec![] - } - }) - .collect(); // Our leaf value - let siblings: Vec> = (0..depth) - .map(|i| { - vec![ - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::ZERO, - F::from_u64((i + 1) * 10), - ] - }) - .collect(); - - // the index is 0b1010... - let directions: Vec = (0..depth).map(|i| i % 2 == 0).collect(); - - let MmcsPrivateData { - path_states: intermediate_states, - .. - } = MmcsPrivateData::new( - &compress, - &mmcs_config, - &leaves_value, - &siblings, - &directions, - )?; - let expected_root_value = intermediate_states - .last() - .expect("There is always at least the leaf hash") - .0 - .clone(); - - let mut public_inputs = vec![]; - public_inputs.extend(leaves_value.iter().flatten()); - public_inputs.extend(directions.iter().map(|dir| F::from_bool(*dir))); - public_inputs.extend(&expected_root_value); - - runner.set_public_inputs(&public_inputs)?; - // Set private Mmcs path data - runner.set_non_primitive_op_private_data( - mmcs_op_id, - NonPrimitiveOpPrivateData::MmcsVerify(MmcsPrivateData::new( - &compress, - &mmcs_config, - &leaves_value, - &siblings, - &directions, - )?), - )?; - let traces = runner.run()?; - let multi_prover = MultiTableProver::new(config).with_mmcs_table(mmcs_config.into()); - let proof = multi_prover.prove_all_tables(&traces)?; - multi_prover.verify_all_tables(&proof)?; - - Ok(()) -} diff --git a/circuit-prover/src/batch_stark_prover.rs b/circuit-prover/src/batch_stark_prover.rs index e9aa70a..c10b8ef 100644 --- a/circuit-prover/src/batch_stark_prover.rs +++ b/circuit-prover/src/batch_stark_prover.rs @@ -16,7 +16,40 @@ use tracing::instrument; use crate::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir}; use crate::config::StarkField; use crate::field_params::ExtractBinomialW; -use crate::prover::TablePacking; + +// Packing configuration for Add/Mul tables. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TablePacking { + add_lanes: usize, + mul_lanes: usize, +} + +impl TablePacking { + pub fn new(add_lanes: usize, mul_lanes: usize) -> Self { + Self { + add_lanes: add_lanes.max(1), + mul_lanes: mul_lanes.max(1), + } + } + + pub fn from_counts(add_lanes: usize, mul_lanes: usize) -> Self { + Self::new(add_lanes, mul_lanes) + } + + pub const fn add_lanes(self) -> usize { + self.add_lanes + } + + pub const fn mul_lanes(self) -> usize { + self.mul_lanes + } +} + +impl Default for TablePacking { + fn default() -> Self { + Self::new(1, 1) + } +} #[repr(usize)] #[derive(Copy, Clone, Debug, PartialEq, Eq)] diff --git a/circuit-prover/src/config.rs b/circuit-prover/src/config.rs index 55d2f52..a8a0c57 100644 --- a/circuit-prover/src/config.rs +++ b/circuit-prover/src/config.rs @@ -193,7 +193,7 @@ where /// /// ```ignore /// let config = config::baby_bear().build(); -/// let prover = MultiTableProver::new(config); +/// let prover = BatchStarkProver::new(config); /// ``` #[inline] pub fn baby_bear() @@ -228,7 +228,7 @@ pub fn baby_bear_compression() -> impl PseudoCompressionFunction<[BabyBear; 8], /// /// ```ignore /// let config = config::koala_bear().build(); -/// let prover = MultiTableProver::new(config); +/// let prover = BatchStarkProver::new(config); /// ``` #[inline] pub fn koala_bear() @@ -263,7 +263,7 @@ pub fn koala_bear_compression() -> impl PseudoCompressionFunction<[KoalaBear; 8] /// /// ```ignore /// let config = config::goldilocks().build(); -/// let prover = MultiTableProver::new(config); +/// let prover = BatchStarkProver::new(config); /// ``` #[inline] pub fn goldilocks() diff --git a/circuit-prover/src/lib.rs b/circuit-prover/src/lib.rs index f421c4c..aad3dd7 100644 --- a/circuit-prover/src/lib.rs +++ b/circuit-prover/src/lib.rs @@ -8,7 +8,7 @@ //! - `CD`: FRI challenge field degree, independent of `D`. //! //! - Build a field-specific config via `config::{babybear_config, koalabear_config, goldilocks_config}`. -//! - Create a `MultiTableProver` from that config. +//! - Create a `BatchStarkProver` from that config. //! - Generate traces from a `p3_circuit::Circuit` runner and prove/verify. //! //! Example (BabyBear): @@ -17,7 +17,7 @@ //! use p3_baby_bear::BabyBear; //! use p3_circuit::builder::CircuitBuilder; //! use p3_circuit_prover::config::babybear_config::build_standard_config_babybear; -//! use p3_circuit_prover::MultiTableProver; +//! use p3_circuit_prover::BatchStarkProver; //! //! let mut builder = CircuitBuilder::::new(); //! let x = builder.add_public_input(); @@ -29,7 +29,7 @@ //! runner.set_public_inputs(&[BabyBear::from_u64(1), BabyBear::from_u64(2)]).unwrap(); //! let traces = runner.run().unwrap(); //! let cfg = build_standard_config_babybear(); -//! let prover = MultiTableProver::new(cfg); +//! let prover = BatchStarkProver::new(cfg); //! let proof = prover.prove_all_tables(&traces).unwrap(); //! prover.verify_all_tables(&proof).unwrap(); //! ``` @@ -41,8 +41,6 @@ pub mod air; pub mod batch_stark_prover; pub mod config; pub mod field_params; -pub mod prover; // Re-export main API -pub use batch_stark_prover::{BatchStarkProof, BatchStarkProver}; -pub use prover::{MultiTableProof, MultiTableProver, TablePacking}; +pub use batch_stark_prover::{BatchStarkProof, BatchStarkProver, TablePacking}; diff --git a/circuit-prover/src/prover.rs b/circuit-prover/src/prover.rs deleted file mode 100644 index d9d6ea4..0000000 --- a/circuit-prover/src/prover.rs +++ /dev/null @@ -1,663 +0,0 @@ -//! Multi-table prover and verifier for STARK proofs. -//! -//! Generic roles and degrees: -//! - F: Prover/verifier base field (BabyBear/KoalaBear/Goldilocks). All PCS/FRI arithmetic runs over `F`. -//! - P: Cryptographic permutation over `F` used by the hash/compress functions and challenger. -//! - EF: Element field used in circuit traces. Either the base field `F` or a binomial extension `BinomialExtensionField`. -//! - D: Element-field extension degree. Must equal `EF::DIMENSION` and is used by AIRs like `WitnessAir` to expand EF values into D base limbs. -//! - CD: Challenge field degree for FRI (independent of `D`). The challenger/PCS use `BinomialExtensionField`. -//! -//! Supports base fields (D=1) and binomial extension fields (D>1), with automatic -//! detection of the binomial parameter `W` for extension-field multiplication. - -use alloc::vec; -use alloc::vec::Vec; - -use p3_circuit::tables::Traces; -use p3_circuit::{CircuitBuilderError, CircuitError}; -use p3_field::{BasedVectorSpace, Field}; -use p3_mmcs_air::air::{MmcsTableConfig, MmcsVerifyAir}; -use p3_uni_stark::{StarkGenericConfig, Val, prove, verify}; -use thiserror::Error; -use tracing::instrument; - -use crate::air::{AddAir, ConstAir, MulAir, PublicAir, WitnessAir}; -use crate::config::StarkField; -use crate::field_params::ExtractBinomialW; - -/// Configuration for packing multiple primitive operations into a single AIR row. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct TablePacking { - add_lanes: usize, - mul_lanes: usize, -} - -impl TablePacking { - pub fn new(add_lanes: usize, mul_lanes: usize) -> Self { - Self { - add_lanes: add_lanes.max(1), - mul_lanes: mul_lanes.max(1), - } - } - - pub fn from_counts(add_lanes: usize, mul_lanes: usize) -> Self { - Self::new(add_lanes, mul_lanes) - } - - pub const fn add_lanes(self) -> usize { - self.add_lanes - } - - pub const fn mul_lanes(self) -> usize { - self.mul_lanes - } -} - -impl Default for TablePacking { - fn default() -> Self { - Self::new(1, 1) - } -} - -/// STARK proof type alias for convenience. -pub type StarkProof = p3_uni_stark::Proof; - -/// Proof and metadata for a single table. -pub struct TableProof -where - SC: StarkGenericConfig, -{ - pub proof: StarkProof, - /// Number of logical rows (operations) prior to any per-row packing. - pub rows: usize, -} - -/// Complete proof bundle containing proofs for all circuit tables. -/// -/// Includes metadata for verification, such as: -/// - `ext_degree`: circuit element extension degree used in traces (may differ from challenge degree). -/// - `w_binomial`: binomial parameter `W` for element-field multiplication, when applicable. -pub struct MultiTableProof -where - SC: StarkGenericConfig, -{ - pub witness: TableProof, - pub constants: TableProof, - pub public: TableProof, - pub add: TableProof, - pub mul: TableProof, - pub mmcs: TableProof, - /// Packing configuration used when generating the proofs. - pub table_packing: TablePacking, - /// Extension field degree: 1 for base field; otherwise the extension degree used. - pub ext_degree: usize, - /// Binomial parameter W for extension fields (e.g., x^D = W); None for base fields - pub w_binomial: Option>, -} - -/// Multi-table STARK prover for circuit execution traces. -/// -/// Generic over `SC: StarkGenericConfig` to support different field configurations. -pub struct MultiTableProver -where - SC: StarkGenericConfig, -{ - config: SC, - table_packing: TablePacking, - mmcs_config: MmcsTableConfig, -} - -/// Errors that can arise during proving or verification. -#[derive(Debug, Error)] -pub enum ProverError { - /// Unsupported extension degree encountered. - #[error("unsupported extension degree: {0} (supported: 1,2,4,6,8)")] - UnsupportedDegree(usize), - - /// Missing binomial parameter W for extension-field multiplication. - #[error("missing binomial parameter W for extension-field multiplication")] - MissingWForExtension, - - /// Circuit execution error. - #[error("circuit error: {0}")] - Circuit(#[from] CircuitError), - - /// Circuit building/lowering error. - #[error("circuit build error: {0}")] - Builder(#[from] CircuitBuilderError), - - /// Verification failed for a specific table/phase. - #[error("verification failed in {phase}")] - VerificationFailed { phase: &'static str }, -} - -impl MultiTableProver -where - SC: StarkGenericConfig, - Val: StarkField, -{ - pub fn new(config: SC) -> Self { - Self { - config, - table_packing: TablePacking::default(), - mmcs_config: MmcsTableConfig::default(), - } - } - - pub fn with_table_packing(mut self, table_packing: TablePacking) -> Self { - self.table_packing = table_packing; - self - } - - pub fn set_table_packing(&mut self, table_packing: TablePacking) { - self.table_packing = table_packing; - } - - pub const fn table_packing(&self) -> TablePacking { - self.table_packing - } - - pub fn with_mmcs_table(mut self, mmcs_config: MmcsTableConfig) -> Self { - self.mmcs_config = mmcs_config; - self - } - - /// Generate proofs for all circuit tables. - /// - /// Automatically detects whether to use base field or binomial extension field - /// proving based on the circuit element type `EF`. For extension fields, - /// the binomial parameter W is automatically extracted. - #[instrument(skip_all)] - pub fn prove_all_tables( - &self, - traces: &Traces, - ) -> Result, ProverError> - where - EF: Field + BasedVectorSpace> + ExtractBinomialW>, - { - let pis = vec![]; - let w_opt = EF::extract_w(); - match EF::DIMENSION { - 1 => self.prove_for_degree::(traces, &pis, None), - 2 => self.prove_for_degree::(traces, &pis, w_opt), - 4 => self.prove_for_degree::(traces, &pis, w_opt), - 6 => self.prove_for_degree::(traces, &pis, w_opt), - 8 => self.prove_for_degree::(traces, &pis, w_opt), - d => Err(ProverError::UnsupportedDegree(d)), - } - } - - /// Verify all proofs in the given proof bundle. - /// Uses the recorded extension degree and binomial parameter recorded during proving. - pub fn verify_all_tables(&self, proof: &MultiTableProof) -> Result<(), ProverError> { - let pis = vec![]; - - let w_opt = proof.w_binomial; - match proof.ext_degree { - 1 => self.verify_for_degree::<1>(proof, &pis, None), - 2 => self.verify_for_degree::<2>(proof, &pis, w_opt), - 4 => self.verify_for_degree::<4>(proof, &pis, w_opt), - 6 => self.verify_for_degree::<6>(proof, &pis, w_opt), - 8 => self.verify_for_degree::<8>(proof, &pis, w_opt), - d => Err(ProverError::UnsupportedDegree(d)), - } - } - - // Internal implementation methods - - /// Prove all tables for a fixed extension degree. - fn prove_for_degree( - &self, - traces: &Traces, - pis: &Vec>, - w_binomial: Option>, - ) -> Result, ProverError> - where - EF: Field + BasedVectorSpace>, - { - debug_assert_eq!(D, EF::DIMENSION, "D parameter must match EF::DIMENSION"); - let table_packing = self.table_packing; - let add_lanes = table_packing.add_lanes(); - let mul_lanes = table_packing.mul_lanes(); - // Witness - let witness_matrix = WitnessAir::, D>::trace_to_matrix(&traces.witness_trace); - let witness_air = WitnessAir::, D>::new(traces.witness_trace.values.len()); - let witness_proof = prove(&self.config, &witness_air, witness_matrix, pis); - - // Const - let const_matrix = ConstAir::, D>::trace_to_matrix(&traces.const_trace); - let const_air = ConstAir::, D>::new(traces.const_trace.values.len()); - let const_proof = prove(&self.config, &const_air, const_matrix, pis); - - // Public - let public_matrix = PublicAir::, D>::trace_to_matrix(&traces.public_trace); - let public_air = PublicAir::, D>::new(traces.public_trace.values.len()); - let public_proof = prove(&self.config, &public_air, public_matrix, pis); - - // Add - let add_matrix = AddAir::, D>::trace_to_matrix(&traces.add_trace, add_lanes); - let add_air = AddAir::, D>::new(traces.add_trace.lhs_values.len(), add_lanes); - let add_proof = prove(&self.config, &add_air, add_matrix, pis); - - // Multiplication (uses binomial arithmetic for extension fields) - let mul_matrix = MulAir::, D>::trace_to_matrix(&traces.mul_trace, mul_lanes); - let mul_air: MulAir, D> = if D == 1 { - MulAir::, D>::new(traces.mul_trace.lhs_values.len(), mul_lanes) - } else { - let w = w_binomial.ok_or(ProverError::MissingWForExtension)?; - MulAir::, D>::new_binomial(traces.mul_trace.lhs_values.len(), mul_lanes, w) - }; - let mul_proof = prove(&self.config, &mul_air, mul_matrix, pis); - - let mmcs_matrix = MmcsVerifyAir::trace_to_matrix(&self.mmcs_config, &traces.mmcs_trace); - let mmcs_air = MmcsVerifyAir::new(self.mmcs_config); - let mmcs_proof = prove(&self.config, &mmcs_air, mmcs_matrix, pis); - - Ok(MultiTableProof { - witness: TableProof { - proof: witness_proof, - rows: traces.witness_trace.values.len(), - }, - constants: TableProof { - proof: const_proof, - rows: traces.const_trace.values.len(), - }, - public: TableProof { - proof: public_proof, - rows: traces.public_trace.values.len(), - }, - add: TableProof { - proof: add_proof, - rows: traces.add_trace.lhs_values.len(), - }, - mul: TableProof { - proof: mul_proof, - rows: traces.mul_trace.lhs_values.len(), - }, - mmcs: TableProof { - proof: mmcs_proof, - rows: traces - .mmcs_trace - .mmcs_paths - .iter() - .map(|path| path.left_values.len() + 1) - .sum(), - }, - table_packing, - ext_degree: D, - w_binomial: if D > 1 { w_binomial } else { None }, - }) - } - - /// Verify all tables for a fixed extension degree. - fn verify_for_degree( - &self, - proof: &MultiTableProof, - pis: &Vec>, - w_binomial: Option>, - ) -> Result<(), ProverError> { - let table_packing = proof.table_packing; - let add_lanes = table_packing.add_lanes(); - let mul_lanes = table_packing.mul_lanes(); - // Witness - let witness_air = WitnessAir::, D>::new(proof.witness.rows); - verify(&self.config, &witness_air, &proof.witness.proof, pis) - .map_err(|_| ProverError::VerificationFailed { phase: "witness" })?; - - // Const - let const_air = ConstAir::, D>::new(proof.constants.rows); - verify(&self.config, &const_air, &proof.constants.proof, pis) - .map_err(|_| ProverError::VerificationFailed { phase: "const" })?; - - // Public - let public_air = PublicAir::, D>::new(proof.public.rows); - verify(&self.config, &public_air, &proof.public.proof, pis) - .map_err(|_| ProverError::VerificationFailed { phase: "public" })?; - - // Add - let add_air = AddAir::, D>::new(proof.add.rows, add_lanes); - verify(&self.config, &add_air, &proof.add.proof, pis) - .map_err(|_| ProverError::VerificationFailed { phase: "add" })?; - - // Mul - let mul_air: MulAir, D> = if D == 1 { - MulAir::new(proof.mul.rows, mul_lanes) - } else { - let w = w_binomial.ok_or(ProverError::MissingWForExtension)?; - MulAir::new_binomial(proof.mul.rows, mul_lanes, w) - }; - verify(&self.config, &mul_air, &proof.mul.proof, pis) - .map_err(|_| ProverError::VerificationFailed { phase: "mul" })?; - - // MmcsVerify - - let mmcs_air = MmcsVerifyAir::new(self.mmcs_config); - verify(&self.config, &mmcs_air, &proof.mmcs.proof, pis).map_err(|_| { - ProverError::VerificationFailed { - phase: "mmcs_verify", - } - })?; - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use p3_baby_bear::BabyBear; - use p3_circuit::CircuitBuilder; - use p3_field::extension::BinomialExtensionField; - use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; - use p3_goldilocks::Goldilocks; - use p3_koala_bear::KoalaBear; - - use super::*; - use crate::config; - - #[test] - fn test_babybear_prover_base_field() -> Result<(), ProverError> { - let mut builder = CircuitBuilder::::new(); - - // Create circuit: x + 5 * 2 - 3 + (-1) = expected_result, then assert result == expected - let x = builder.add_public_input(); - let expected_result = builder.add_public_input(); // Add expected result as public input - let c5 = builder.add_const(BabyBear::from_u64(5)); - let c2 = builder.add_const(BabyBear::from_u64(2)); - let c3 = builder.add_const(BabyBear::from_u64(3)); - let neg_one = builder.add_const(BabyBear::NEG_ONE); // Field boundary test - - let mul_result = builder.mul(c5, c2); // 5 * 2 = 10 - let add_result = builder.add(x, mul_result); // x + 10 - let sub_result = builder.sub(add_result, c3); // (x + 10) - 3 - let final_result = builder.add(sub_result, neg_one); // + (-1) for boundary - - // Constrain: final_result - expected_result == 0 - let diff = builder.sub(final_result, expected_result); - builder.assert_zero(diff); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - // Set public inputs: x = 7, expected = 7 + 10 - 3 + (-1) = 13 - let x_val = BabyBear::from_u64(7); - let expected_val = BabyBear::from_u64(13); // 7 + 10 - 3 - 1 = 13 - runner.set_public_inputs(&[x_val, expected_val])?; - - let traces = runner.run()?; - - // Create BabyBear prover and prove all tables - let config = config::baby_bear().build(); - let multi_prover = MultiTableProver::new(config); - let proof = multi_prover.prove_all_tables(&traces)?; - - // Verify all proofs - multi_prover.verify_all_tables(&proof)?; - Ok(()) - } - - #[test] - fn test_babybear_prover_extension_field_d4() -> Result<(), ProverError> { - type ExtField = BinomialExtensionField; - let mut builder = CircuitBuilder::::new(); - - // Create circuit: x * y + z - w = expected_result, then assert result == expected - let x = builder.add_public_input(); - let y = builder.add_public_input(); - let z = builder.add_public_input(); - let expected_result = builder.add_public_input(); // Add expected result as public input - let w = builder.add_const( - ExtField::from_basis_coefficients_slice(&[ - BabyBear::NEG_ONE, // -1 boundary test - BabyBear::ZERO, - BabyBear::ONE, - BabyBear::TWO, - ]) - .unwrap(), - ); - - let xy = builder.mul(x, y); // Extension field multiplication - let add_result = builder.add(xy, z); - let sub_result = builder.sub(add_result, w); - - // Constrain: sub_result - expected_result == 0 - let diff = builder.sub(sub_result, expected_result); - builder.assert_zero(diff); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - // Set public inputs with all non-zero coefficients - let x_val = ExtField::from_basis_coefficients_slice(&[ - BabyBear::from_u64(2), - BabyBear::from_u64(3), - BabyBear::from_u64(5), - BabyBear::from_u64(7), - ]) - .unwrap(); - let y_val = ExtField::from_basis_coefficients_slice(&[ - BabyBear::from_u64(11), - BabyBear::from_u64(13), - BabyBear::from_u64(17), - BabyBear::from_u64(19), - ]) - .unwrap(); - let z_val = ExtField::from_basis_coefficients_slice(&[ - BabyBear::from_u64(23), - BabyBear::from_u64(29), - BabyBear::from_u64(31), - BabyBear::from_u64(37), - ]) - .unwrap(); - let w_val = ExtField::from_basis_coefficients_slice(&[ - BabyBear::NEG_ONE, - BabyBear::ZERO, - BabyBear::ONE, - BabyBear::TWO, - ]) - .unwrap(); - - // Compute expected result: x * y + z - w - let xy_expected = x_val * y_val; - let add_expected = xy_expected + z_val; - let expected_val = add_expected - w_val; - - runner.set_public_inputs(&[x_val, y_val, z_val, expected_val])?; - let traces = runner.run()?; - - // Create BabyBear prover for extension field (D=4) - let config = config::baby_bear().build(); - let multi_prover = MultiTableProver::new(config); - let proof = multi_prover.prove_all_tables(&traces)?; - - // Verify proof has correct extension degree and W parameter - assert_eq!(proof.ext_degree, 4); - // Derive W via trait to avoid hardcoding constants - let expected_w = >::extract_w().unwrap(); - assert_eq!(proof.w_binomial, Some(expected_w)); - - multi_prover.verify_all_tables(&proof)?; - Ok(()) - } - - #[test] - fn test_koalabear_prover_base_field() -> Result<(), ProverError> { - let mut builder = CircuitBuilder::::new(); - - // Create circuit: a * b + c - d = expected_result, then assert result == expected - let a = builder.add_public_input(); - let b = builder.add_public_input(); - let expected_result = builder.add_public_input(); // Add expected result as public input - let c = builder.add_const(KoalaBear::from_u64(100)); - let d = builder.add_const(KoalaBear::NEG_ONE); // Boundary test - - let ab = builder.mul(a, b); - let add_result = builder.add(ab, c); - let final_result = builder.sub(add_result, d); - - // Constrain: final_result - expected_result == 0 - let diff = builder.sub(final_result, expected_result); - builder.assert_zero(diff); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - // Set public inputs: a=42, b=13, expected = 42*13 + 100 - (-1) = 546 + 100 + 1 = 647 - let a_val = KoalaBear::from_u64(42); - let b_val = KoalaBear::from_u64(13); - let expected_val = KoalaBear::from_u64(647); // 42*13 + 100 - (-1) = 647 - runner.set_public_inputs(&[a_val, b_val, expected_val])?; - let traces = runner.run()?; - - // Create KoalaBear prover - let config = config::koala_bear().build(); - let multi_prover = MultiTableProver::new(config); - let proof = multi_prover.prove_all_tables(&traces)?; - - multi_prover.verify_all_tables(&proof)?; - Ok(()) - } - - #[test] - fn test_koalabear_prover_extension_field_d8() -> Result<(), ProverError> { - type KBExtField = BinomialExtensionField; - let mut builder = CircuitBuilder::::new(); - - // Create circuit: x * y * z = expected_result, then assert result == expected - let x = builder.add_public_input(); - let y = builder.add_public_input(); - let expected_result = builder.add_public_input(); // Add expected result as public input - let z = builder.add_const( - KBExtField::from_basis_coefficients_slice(&[ - KoalaBear::from_u64(1), - KoalaBear::NEG_ONE, // Mix: 1 and -1 - KoalaBear::from_u64(2), - KoalaBear::from_u64(3), - KoalaBear::from_u64(4), - KoalaBear::from_u64(5), - KoalaBear::from_u64(6), - KoalaBear::from_u64(7), - ]) - .unwrap(), - ); - - let xy = builder.mul(x, y); // First extension multiplication - let xyz = builder.mul(xy, z); // Second extension multiplication - - // Constrain: xyz - expected_result == 0 - let diff = builder.sub(xyz, expected_result); - builder.assert_zero(diff); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - // Set public inputs with diverse coefficients - let x_val = KBExtField::from_basis_coefficients_slice(&[ - KoalaBear::from_u64(4), - KoalaBear::from_u64(6), - KoalaBear::from_u64(8), - KoalaBear::from_u64(10), - KoalaBear::from_u64(12), - KoalaBear::from_u64(14), - KoalaBear::from_u64(16), - KoalaBear::from_u64(18), - ]) - .unwrap(); - let y_val = KBExtField::from_basis_coefficients_slice(&[ - KoalaBear::from_u64(12), - KoalaBear::from_u64(14), - KoalaBear::from_u64(16), - KoalaBear::from_u64(18), - KoalaBear::from_u64(20), - KoalaBear::from_u64(22), - KoalaBear::from_u64(24), - KoalaBear::from_u64(26), - ]) - .unwrap(); - let z_val = KBExtField::from_basis_coefficients_slice(&[ - KoalaBear::from_u64(1), - KoalaBear::NEG_ONE, - KoalaBear::from_u64(2), - KoalaBear::from_u64(3), - KoalaBear::from_u64(4), - KoalaBear::from_u64(5), - KoalaBear::from_u64(6), - KoalaBear::from_u64(7), - ]) - .unwrap(); - - // Compute expected result: x * y * z - let xy_expected = x_val * y_val; - let expected_val = xy_expected * z_val; - - runner.set_public_inputs(&[x_val, y_val, expected_val])?; - let traces = runner.run()?; - - // Create KoalaBear prover for extension field (D=8) - let config = config::koala_bear().build(); - let multi_prover = MultiTableProver::new(config); - let proof = multi_prover.prove_all_tables(&traces)?; - - // Verify proof has correct extension degree and W parameter for KoalaBear (D=8) - assert_eq!(proof.ext_degree, 8); - let expected_w_kb = >::extract_w().unwrap(); - assert_eq!(proof.w_binomial, Some(expected_w_kb)); - - multi_prover.verify_all_tables(&proof)?; - Ok(()) - } - - #[test] - fn test_goldilocks_prover_extension_field_d2() -> Result<(), ProverError> { - type ExtField = BinomialExtensionField; - let mut builder = CircuitBuilder::::new(); - - // Simple circuit over D=2: x * y + z = expected - let x = builder.add_public_input(); - let y = builder.add_public_input(); - let z = builder.add_public_input(); - let expected_result = builder.add_public_input(); - - let xy = builder.mul(x, y); - let res = builder.add(xy, z); - - let diff = builder.sub(res, expected_result); - builder.assert_zero(diff); - - let circuit = builder.build()?; - let mut runner = circuit.runner(); - - let x_val = ExtField::from_basis_coefficients_slice(&[ - Goldilocks::from_u64(3), - Goldilocks::NEG_ONE, - ]) - .unwrap(); - let y_val = ExtField::from_basis_coefficients_slice(&[ - Goldilocks::from_u64(7), - Goldilocks::from_u64(11), - ]) - .unwrap(); - let z_val = ExtField::from_basis_coefficients_slice(&[ - Goldilocks::from_u64(13), - Goldilocks::from_u64(17), - ]) - .unwrap(); - - let expected_val = x_val * y_val + z_val; - runner.set_public_inputs(&[x_val, y_val, z_val, expected_val])?; - let traces = runner.run()?; - - // Build Goldilocks config with challenge degree 2 (Poseidon2) - let config = config::goldilocks().build(); - let multi_prover = MultiTableProver::new(config); - - let proof = multi_prover.prove_all_tables(&traces)?; - - // Check extension metadata and verify - assert_eq!(proof.ext_degree, 2); - let expected_w = >::extract_w().unwrap(); - assert_eq!(proof.w_binomial, Some(expected_w)); - - multi_prover.verify_all_tables(&proof)?; - Ok(()) - } -}