diff --git a/circuit-prover/src/prover.rs b/circuit-prover/src/prover.rs index d9d6ea4..90e3cea 100644 --- a/circuit-prover/src/prover.rs +++ b/circuit-prover/src/prover.rs @@ -11,7 +11,6 @@ //! detection of the binomial parameter `W` for extension-field multiplication. use alloc::vec; -use alloc::vec::Vec; use p3_circuit::tables::Traces; use p3_circuit::{CircuitBuilderError, CircuitError}; @@ -209,7 +208,7 @@ where fn prove_for_degree( &self, traces: &Traces, - pis: &Vec>, + pis: &[Val], w_binomial: Option>, ) -> Result, ProverError> where @@ -293,7 +292,7 @@ where fn verify_for_degree( &self, proof: &MultiTableProof, - pis: &Vec>, + pis: &[Val], w_binomial: Option>, ) -> Result<(), ProverError> { let table_packing = proof.table_packing; diff --git a/circuit/Cargo.toml b/circuit/Cargo.toml index 99bb7e4..92d3052 100644 --- a/circuit/Cargo.toml +++ b/circuit/Cargo.toml @@ -21,6 +21,7 @@ p3-koala-bear.workspace = true p3-matrix.workspace = true p3-symmetric.workspace = true p3-uni-stark.workspace = true +p3-util.workspace = true rand.workspace = true # Other common dependencies diff --git a/circuit/src/builder/circuit_builder.rs b/circuit/src/builder/circuit_builder.rs index 0422ab7..8ec907c 100644 --- a/circuit/src/builder/circuit_builder.rs +++ b/circuit/src/builder/circuit_builder.rs @@ -6,10 +6,11 @@ use itertools::zip_eq; use p3_field::{Field, PrimeCharacteristicRing}; use super::compiler::{ExpressionLowerer, NonPrimitiveLowerer, Optimizer}; -use super::{BuilderConfig, ExpressionBuilder, PublicInputTracker}; +use super::{BuilderConfig, ExpressionBuilder}; use crate::CircuitBuilderError; +use crate::builder::public_input_tracker::PublicInputTracker; use crate::circuit::Circuit; -use crate::op::NonPrimitiveOpType; +use crate::op::{DefaultHint, NonPrimitiveOpType, WitnessHintFiller}; use crate::ops::MmcsVerifyConfig; use crate::types::{ExprId, NonPrimitiveOpId, WitnessAllocator, WitnessId}; @@ -120,16 +121,27 @@ where self.public_tracker.count() } - /// Allocates a witness hint (uninitialized witness slot set during non-primitive execution). + /// Allocates multiple witnesses. Witness hints are placeholders for values that will later be provided by a + /// `filler`. #[must_use] - pub fn alloc_witness_hint(&mut self, label: &'static str) -> ExprId { - self.expr_builder.add_witness_hint(label) + pub fn alloc_witness_hints>( + &mut self, + filler: W, + label: &'static str, + ) -> Vec { + self.expr_builder.add_witness_hints(filler, label) } - /// Allocates multiple witness hints. + /// Allocates multiple witness hints without saying how they should be filled + /// TODO: Remove this function #[must_use] - pub fn alloc_witness_hints(&mut self, count: usize, label: &'static str) -> Vec { - self.expr_builder.add_witness_hints(count, label) + pub fn alloc_witness_hints_no_filler( + &mut self, + count: usize, + label: &'static str, + ) -> Vec { + self.expr_builder + .add_witness_hints(DefaultHint { n_outputs: count }, label) } /// Adds a constant to the circuit (deduplicated). @@ -394,6 +406,7 @@ where self.expr_builder.graph(), self.expr_builder.pending_connects(), self.public_tracker.count(), + self.expr_builder.hints_with_fillers(), self.witness_alloc, ); let (primitive_ops, public_rows, expr_to_widx, public_mappings, witness_count) = @@ -423,11 +436,9 @@ where #[cfg(test)] mod tests { use alloc::vec; - use alloc::vec::Vec; use p3_baby_bear::BabyBear; use p3_field::PrimeCharacteristicRing; - use proptest::prelude::*; use super::*; @@ -700,6 +711,41 @@ mod tests { assert_eq!(circuit.primitive_ops.len(), 2); } + #[test] + fn test_build_with_witness_hint() { + let mut builder = CircuitBuilder::::new(); + let default_hint = DefaultHint { n_outputs: 1 }; + let a = builder.alloc_witness_hints(default_hint, "a"); + assert_eq!(a.len(), 1); + let circuit = builder + .build() + .expect("Circuit with operations should build"); + + assert_eq!(circuit.witness_count, 2); + assert_eq!(circuit.primitive_ops.len(), 2); + + match &circuit.primitive_ops[1] { + crate::op::Op::Unconstrained { + inputs, outputs, .. + } => { + assert_eq!(*inputs, vec![]); + assert_eq!(*outputs, vec![WitnessId(1)]); + } + _ => panic!("Expected Unconstrained at index 0"), + } + } +} + +#[cfg(test)] +mod proptests { + use alloc::vec; + + use p3_baby_bear::BabyBear; + use p3_field::PrimeCharacteristicRing; + use proptest::prelude::*; + + use super::*; + // Strategy for generating valid field elements fn field_element() -> impl Strategy { any::().prop_map(BabyBear::from_u64) diff --git a/circuit/src/builder/compiler/expression_lowerer.rs b/circuit/src/builder/compiler/expression_lowerer.rs index 2488c11..16191dd 100644 --- a/circuit/src/builder/compiler/expression_lowerer.rs +++ b/circuit/src/builder/compiler/expression_lowerer.rs @@ -1,3 +1,5 @@ +use alloc::boxed::Box; +use alloc::string::ToString; use alloc::vec; use alloc::vec::Vec; @@ -8,6 +10,7 @@ use crate::Op; use crate::builder::CircuitBuilderError; use crate::builder::compiler::get_witness_id; use crate::expr::{Expr, ExpressionGraph}; +use crate::op::WitnessHintFiller; use crate::types::{ExprId, WitnessAllocator, WitnessId}; /// Sparse disjoint-set "find" with path compression over a HashMap (iterative). @@ -70,6 +73,9 @@ pub struct ExpressionLowerer<'a, F> { /// Number of public inputs public_input_count: usize, + /// The hint witnesses with their respective filler + hints_fillers: &'a [Box>], + /// Witness allocator witness_alloc: WitnessAllocator, } @@ -83,12 +89,14 @@ where graph: &'a ExpressionGraph, pending_connects: &'a [(ExprId, ExprId)], public_input_count: usize, + hints_fillers: &'a [Box>], witness_alloc: WitnessAllocator, ) -> Self { Self { graph, pending_connects, public_input_count, + hints_fillers, witness_alloc, } } @@ -162,7 +170,7 @@ where } }; - // Pass B: emit public inputs + // Pass B: emit public inputs and process witness hints for (expr_idx, expr) in self.graph.nodes().iter().enumerate() { if let Expr::Public(pos) = expr { let id = ExprId(expr_idx as u32); @@ -179,17 +187,42 @@ where } } - // Pass C: emit arithmetic ops in creation order; tie outputs to class slot if connected + // Pass C: emit arithmetic and unconstrained ops in creation order; tie outputs to class slot if connected + let mut hints_sequence = vec![]; + let mut fillers_iter = self.hints_fillers.iter().cloned(); for (expr_idx, expr) in self.graph.nodes().iter().enumerate() { let expr_id = ExprId(expr_idx as u32); match expr { Expr::Const(_) | Expr::Public(_) => { /* handled above */ } - Expr::Witness => { - // Allocate a fresh witness slot (non-primitive op) - // Allows non-primitive operations to set values during execution that - // are not part of the central Witness bus. + Expr::Witness { last_hint } => { + let expr_id = ExprId(expr_idx as u32); let out_widx = alloc_witness_id_for_expr(expr_idx); expr_to_widx.insert(expr_id, out_widx); + hints_sequence.push(out_widx); + if *last_hint { + let filler = fillers_iter.next().expect( + "By construction, every sequence of witness must haver one filler", + ); + let inputs = filler + .inputs() + .iter() + .map(|expr_id| { + expr_to_widx + .get(expr_id) + .ok_or(CircuitBuilderError::MissingExprMapping { + expr_id: *expr_id, + context: "Unconstrained op".to_string(), + }) + .copied() + }) + .collect::, _>>()?; + primitive_ops.push(Op::Unconstrained { + inputs, + outputs: hints_sequence, + filler, + }); + hints_sequence = vec![]; + } } Expr::Add { lhs, rhs } => { let out_widx = alloc_witness_id_for_expr(expr_idx); @@ -394,9 +427,10 @@ mod tests { let quot = graph.add_expr(Expr::Div { lhs: diff, rhs: p2 }); let connects = vec![]; + let hints_fillers = vec![]; let alloc = WitnessAllocator::new(); - let lowerer = ExpressionLowerer::new(&graph, &connects, 3, alloc); + let lowerer = ExpressionLowerer::new(&graph, &connects, 3, &hints_fillers, alloc); let (prims, public_rows, expr_map, public_map, witness_count) = lowerer.lower().unwrap(); // Verify Primitives @@ -563,9 +597,10 @@ mod tests { // Group B: p1 ~ p2 ~ p3 (transitive) // Group C: sum ~ p4 (operation result shared) let connects = vec![(c_42, p0), (p1, p2), (p2, p3), (sum, p4)]; + let hints_fillers = vec![]; let alloc = WitnessAllocator::new(); - let lowerer = ExpressionLowerer::new(&graph, &connects, 5, alloc); + let lowerer = ExpressionLowerer::new(&graph, &connects, 5, &hints_fillers, alloc); let (prims, public_rows, expr_map, public_map, witness_count) = lowerer.lower().unwrap(); // Verify Primitives @@ -703,8 +738,9 @@ mod tests { }); let connects = vec![]; + let hints_fillers = vec![]; let alloc = WitnessAllocator::new(); - let lowerer = ExpressionLowerer::new(&graph, &connects, 0, alloc); + let lowerer = ExpressionLowerer::new(&graph, &connects, 0, &hints_fillers, alloc); let result = lowerer.lower(); assert!(result.is_err()); @@ -724,8 +760,9 @@ mod tests { }); let connects = vec![]; + let hints_fillers = vec![]; let alloc = WitnessAllocator::new(); - let lowerer = ExpressionLowerer::new(&graph, &connects, 0, alloc); + let lowerer = ExpressionLowerer::new(&graph, &connects, 0, &hints_fillers, alloc); let result = lowerer.lower(); assert!(result.is_err()); diff --git a/circuit/src/builder/expression_builder.rs b/circuit/src/builder/expression_builder.rs index ca6fc31..4ccceeb 100644 --- a/circuit/src/builder/expression_builder.rs +++ b/circuit/src/builder/expression_builder.rs @@ -1,11 +1,14 @@ +use alloc::boxed::Box; #[cfg(debug_assertions)] use alloc::vec; use alloc::vec::Vec; use hashbrown::HashMap; +use itertools::Itertools; use p3_field::PrimeCharacteristicRing; use crate::expr::{Expr, ExpressionGraph}; +use crate::op::WitnessHintFiller; use crate::types::ExprId; #[cfg(debug_assertions)] use crate::{AllocationEntry, AllocationType}; @@ -22,6 +25,9 @@ pub struct ExpressionBuilder { /// Equality constraints to enforce at lowering pending_connects: Vec<(ExprId, ExprId)>, + /// The witness hints together with theit witness fillers + hints_with_fillers: Vec>>, + /// Debug log of all allocations #[cfg(debug_assertions)] allocation_log: Vec, @@ -50,6 +56,7 @@ where graph, const_pool, pending_connects: Vec::new(), + hints_with_fillers: Vec::new(), #[cfg(debug_assertions)] allocation_log: Vec::new(), #[cfg(debug_assertions)] @@ -100,10 +107,11 @@ where /// Adds a witness hint to the graph. /// It will allocate a `WitnessId` during lowering, with no primitive op. + /// TODO: Make this function private. #[allow(unused_variables)] #[must_use] pub fn add_witness_hint(&mut self, label: &'static str) -> ExprId { - let expr_id = self.graph.add_expr(Expr::Witness); + let expr_id = self.graph.add_expr(Expr::Witness { last_hint: true }); #[cfg(debug_assertions)] self.allocation_log.push(AllocationEntry { @@ -117,10 +125,26 @@ where expr_id } - /// Adds multiple witness hints. + /// Adds `filler.n_outputs()` witness hints to the graph. + /// During circuit evaluation, the `filler` will derive the concrete + /// witness values for these hints. + #[allow(unused_variables)] #[must_use] - pub fn add_witness_hints(&mut self, count: usize, label: &'static str) -> Vec { - (0..count).map(|_| self.add_witness_hint(label)).collect() + pub fn add_witness_hints>( + &mut self, + filler: W, + label: &'static str, + ) -> Vec { + let n_outputs = filler.n_outputs(); + let expr_ids = (0..n_outputs) + .map(|i| { + self.graph.add_expr(Expr::Witness { + last_hint: i == n_outputs - 1, + }) + }) + .collect_vec(); + self.hints_with_fillers.push(Box::new(filler)); + expr_ids } /// Adds an addition expression to the graph. @@ -208,6 +232,11 @@ where &self.pending_connects } + /// Returns a reference to the witness hints with fillers. + pub fn hints_with_fillers(&self) -> &[Box>] { + &self.hints_with_fillers + } + /// Logs a non-primitive operation allocation. #[cfg(debug_assertions)] pub fn log_non_primitive_op( @@ -285,8 +314,10 @@ where #[cfg(test)] mod tests { use p3_baby_bear::BabyBear; + use p3_field::Field; use super::*; + use crate::CircuitError; #[test] fn test_new_builder_has_zero_constant() { @@ -527,6 +558,52 @@ mod tests { } } + #[derive(Debug, Clone)] + struct IdentityHint { + inputs: Vec, + n_outputs: usize, + } + + impl IdentityHint { + pub fn new(inputs: Vec) -> Self { + Self { + n_outputs: inputs.len(), + inputs, + } + } + } + + impl WitnessHintFiller for IdentityHint { + fn inputs(&self) -> &[ExprId] { + &self.inputs + } + + fn n_outputs(&self) -> usize { + self.n_outputs + } + + fn compute_outputs(&self, inputs_val: Vec) -> Result, CircuitError> { + Ok(inputs_val) + } + } + + #[test] + fn test_build_with_witness_hint() { + let mut builder = ExpressionBuilder::::new(); + let a = builder.add_const(BabyBear::ZERO, "a"); + let b = builder.add_const(BabyBear::ONE, "b"); + let id_hint = IdentityHint::new(vec![a, b]); + let c = builder.add_witness_hints(id_hint, "c"); + assert_eq!(c.len(), 2); + + assert_eq!(builder.graph().nodes().len(), 4); + + match (&builder.graph().nodes()[2], &builder.graph().nodes()[3]) { + (Expr::Witness { last_hint: false }, Expr::Witness { last_hint: true }) => (), + _ => panic!("Expected Witness operation"), + } + } + #[test] fn test_nested_operations() { // Test nested operations: (a + b) * (c - d) diff --git a/circuit/src/errors.rs b/circuit/src/errors.rs index efc517f..9d66fb7 100644 --- a/circuit/src/errors.rs +++ b/circuit/src/errors.rs @@ -110,4 +110,7 @@ pub enum CircuitError { /// Invalid Circuit #[error("Failed to build circuit: {error}")] InvalidCircuit { error: CircuitBuilderError }, + + #[error("Unconstrained operation input length mismatch: expected {expected}, got {got}")] + UnconstrainedOpInputLengthMismatch { expected: usize, got: usize }, } diff --git a/circuit/src/expr.rs b/circuit/src/expr.rs index a2ee656..c79088a 100644 --- a/circuit/src/expr.rs +++ b/circuit/src/expr.rs @@ -9,9 +9,11 @@ pub enum Expr { Const(F), /// Public input at declaration position Public(usize), - /// Witness hint - allocates a WitnessId without adding a primitive op - /// The value will be set during non-primitive execution (set-or-verify semantics) - Witness, + /// Witness hints — allocates a `WitnessId` representing a + /// non-deterministic hint. The boolean flag indicates whether + /// this is the last witness in a sequence of related hints, + /// where each sequence is produced through a shared generation process. + Witness { last_hint: bool }, /// Addition of two expressions Add { lhs: ExprId, rhs: ExprId }, /// Subtraction of two expressions diff --git a/circuit/src/op.rs b/circuit/src/op.rs index aafa25e..4807fec 100644 --- a/circuit/src/op.rs +++ b/circuit/src/op.rs @@ -1,4 +1,5 @@ use alloc::boxed::Box; +use alloc::vec; use alloc::vec::Vec; use core::fmt::Debug; use core::hash::Hash; @@ -6,10 +7,10 @@ use core::hash::Hash; use hashbrown::HashMap; use p3_field::Field; -use crate::CircuitError; use crate::ops::MmcsVerifyConfig; use crate::tables::MmcsPrivateData; use crate::types::{NonPrimitiveOpId, WitnessId}; +use crate::{CircuitError, ExprId}; /// Circuit operations. /// @@ -63,6 +64,16 @@ pub enum Op { out: WitnessId, }, + /// Load unconstrained values into the witness table + /// + /// Sets `witness[output]`, for each `output` in `outputs`, to arbitrary values + /// defined by `filler` + Unconstrained { + inputs: Vec, + outputs: Vec, + filler: Box>, + }, + /// Non-primitive operation with executor-based dispatch NonPrimitiveOpWithExecutor { inputs: Vec>, @@ -95,6 +106,15 @@ impl Clone for Op { b: *b, out: *out, }, + Op::Unconstrained { + inputs, + outputs, + filler, + } => Op::Unconstrained { + inputs: inputs.clone(), + outputs: outputs.clone(), + filler: filler.clone(), + }, Op::NonPrimitiveOpWithExecutor { inputs, outputs, @@ -379,3 +399,60 @@ impl Clone for Box> { self.boxed() } } + +/// A trait for defining how unconstrained data (hints) is set. +pub trait WitnessHintFiller: Debug + WitnessFillerClone { + /// Return the `ExprId` of the inputs + fn inputs(&self) -> &[ExprId]; + /// Returns number of outputs filled by this filler + fn n_outputs(&self) -> usize; + /// Compute the output given the inputs + /// # Arguments + /// * `inputs` - Input witness + fn compute_outputs(&self, inputs_val: Vec) -> Result, CircuitError>; +} + +impl Clone for Box> { + fn clone(&self) -> Self { + self.clone_box() + } +} + +#[derive(Debug, Clone, Default)] +pub struct DefaultHint { + pub n_outputs: usize, +} + +impl DefaultHint { + pub fn boxed_default() -> Box> { + Box::new(Self::default()) + } +} + +impl WitnessHintFiller for DefaultHint { + fn inputs(&self) -> &[ExprId] { + &[] + } + + fn n_outputs(&self) -> usize { + self.n_outputs + } + + fn compute_outputs(&self, _inputs_val: Vec) -> Result, CircuitError> { + Ok(vec![F::default(); self.n_outputs]) + } +} + +// Object-safe "clone into Box" helper +pub trait WitnessFillerClone { + fn clone_box(&self) -> Box>; +} + +impl WitnessFillerClone for T +where + T: WitnessHintFiller + Clone + 'static, +{ + fn clone_box(&self) -> Box> { + Box::new(self.clone()) + } +} diff --git a/circuit/src/ops/hash.rs b/circuit/src/ops/hash.rs index 9aad381..fc7d0a9 100644 --- a/circuit/src/ops/hash.rs +++ b/circuit/src/ops/hash.rs @@ -55,7 +55,7 @@ where fn add_hash_squeeze(&mut self, count: usize) -> Result, CircuitBuilderError> { self.ensure_op_enabled(NonPrimitiveOpType::HashSqueeze)?; - let outputs = self.alloc_witness_hints(count, "hash_squeeze_output"); + let outputs = self.alloc_witness_hints_no_filler(count, "hash_squeeze_output"); let _ = self.push_non_primitive_op( NonPrimitiveOpType::HashSqueeze, diff --git a/circuit/src/tables/mmcs.rs b/circuit/src/tables/mmcs.rs index a152498..ac01d78 100644 --- a/circuit/src/tables/mmcs.rs +++ b/circuit/src/tables/mmcs.rs @@ -3,6 +3,7 @@ use alloc::vec::Vec; use alloc::{format, vec}; use core::fmt::Debug; use core::iter; +use core::result::Result; use itertools::izip; use p3_field::{ExtensionField, Field}; diff --git a/circuit/src/tables/runner.rs b/circuit/src/tables/runner.rs index 939b608..66d161e 100644 --- a/circuit/src/tables/runner.rs +++ b/circuit/src/tables/runner.rs @@ -2,6 +2,7 @@ use alloc::string::ToString; use alloc::vec::Vec; use alloc::{format, vec}; +use p3_util::zip_eq::zip_eq; use tracing::instrument; use super::Traces; @@ -189,6 +190,28 @@ impl CircuitRunner { self.set_witness(b, b_val)?; } } + Op::Unconstrained { + inputs, + outputs, + filler, + } => { + let inputs_val = inputs + .iter() + .map(|&input| self.get_witness(input)) + .collect::, _>>()?; + let outputs_val = filler.compute_outputs(inputs_val)?; + + for (&output, &output_val) in zip_eq( + outputs.iter(), + outputs_val.iter(), + CircuitError::UnconstrainedOpInputLengthMismatch { + expected: outputs.len(), + got: outputs_val.len(), + }, + )? { + self.set_witness(output, output_val)? + } + } Op::NonPrimitiveOpWithExecutor { .. } => { // Handled separately in execute_non_primitives } @@ -262,14 +285,17 @@ impl CircuitRunner { mod tests { extern crate std; use alloc::vec; + use alloc::vec::Vec; use std::println; use p3_baby_bear::BabyBear; use p3_field::extension::BinomialExtensionField; - use p3_field::{BasedVectorSpace, PrimeCharacteristicRing}; + use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing}; use crate::builder::CircuitBuilder; + use crate::op::WitnessHintFiller; use crate::types::WitnessId; + use crate::{CircuitError, ExprId}; #[test] fn test_table_generation_basic() { @@ -412,6 +438,148 @@ mod tests { assert_eq!(traces.add_trace.result_index, vec![WitnessId(4)]); } + #[derive(Debug, Clone)] + /// The hint defined by x in an equation a*x - b = 0 + struct XHint { + inputs: Vec, + } + + impl XHint { + pub fn new(a: ExprId, b: ExprId) -> Self { + Self { inputs: vec![a, b] } + } + } + + impl WitnessHintFiller for XHint { + fn inputs(&self) -> &[ExprId] { + &self.inputs + } + + fn n_outputs(&self) -> usize { + 1 + } + + fn compute_outputs(&self, inputs_val: Vec) -> Result, crate::CircuitError> { + if inputs_val.len() != self.inputs.len() { + Err(crate::CircuitError::UnconstrainedOpInputLengthMismatch { + expected: self.inputs.len(), + got: inputs_val.len(), + }) + } else { + let a = inputs_val[0]; + let b = inputs_val[1]; + let inv_a = a.try_inverse().ok_or(CircuitError::DivisionByZero)?; + let x = b * inv_a; + Ok(vec![x]) + } + } + } + + #[test] + // Proves that we know x such that 37 * x - 111 = 0 + fn test_toy_example_37_times_x_minus_111_with_witness_hint() { + let mut builder = CircuitBuilder::new(); + + let c37 = builder.add_const(BabyBear::from_u64(37)); + let c111 = builder.add_const(BabyBear::from_u64(111)); + let x_hint = XHint::new(c37, c111); + let x = builder.alloc_witness_hints(x_hint, "x")[0]; + + let mul_result = builder.mul(c37, x); + let sub_result = builder.sub(mul_result, c111); + builder.assert_zero(sub_result); + + let circuit = builder.build().unwrap(); + println!("=== CIRCUIT PRIMITIVE OPERATIONS ==="); + for (i, prim) in circuit.primitive_ops.iter().enumerate() { + println!("{i}: {prim:?}"); + } + + let witness_count = circuit.witness_count; + let runner = circuit.runner(); + + let traces = runner.run().unwrap(); + + println!("\n=== WITNESS TRACE ==="); + for (i, (idx, val)) in traces + .witness_trace + .index + .iter() + .zip(traces.witness_trace.values.iter()) + .enumerate() + { + println!("Row {i}: WitnessId({idx}) = {val:?}"); + } + + println!("\n=== CONST TRACE ==="); + for (i, (idx, val)) in traces + .const_trace + .index + .iter() + .zip(traces.const_trace.values.iter()) + .enumerate() + { + println!("Row {i}: WitnessId({idx}) = {val:?}"); + } + + println!("\n=== PUBLIC TRACE ==="); + for (i, (idx, val)) in traces + .public_trace + .index + .iter() + .zip(traces.public_trace.values.iter()) + .enumerate() + { + println!("Row {i}: WitnessId({idx}) = {val:?}"); + } + + println!("\n=== MUL TRACE ==="); + for i in 0..traces.mul_trace.lhs_values.len() { + println!( + "Row {}: WitnessId({}) * WitnessId({}) -> WitnessId({}) | {:?} * {:?} -> {:?}", + i, + traces.mul_trace.lhs_index[i], + traces.mul_trace.rhs_index[i], + traces.mul_trace.result_index[i], + traces.mul_trace.lhs_values[i], + traces.mul_trace.rhs_values[i], + traces.mul_trace.result_values[i] + ); + } + + println!("\n=== ADD TRACE ==="); + for i in 0..traces.add_trace.lhs_values.len() { + println!( + "Row {}: WitnessId({}) + WitnessId({}) -> WitnessId({}) | {:?} + {:?} -> {:?}", + i, + traces.add_trace.lhs_index[i], + traces.add_trace.rhs_index[i], + traces.add_trace.result_index[i], + traces.add_trace.lhs_values[i], + traces.add_trace.rhs_values[i], + traces.add_trace.result_values[i] + ); + } + + // Verify trace structure + assert_eq!(traces.witness_trace.index.len(), witness_count as usize); + + // Should have constants: 0, 37, 111 + assert_eq!(traces.const_trace.values.len(), 3); + + // Should have no public input + assert!(traces.public_trace.values.is_empty()); + + // Should have one mul operation: 37 * x + assert_eq!(traces.mul_trace.lhs_values.len(), 1); + + // Encoded subtraction lands in the add table (result + rhs = lhs). + assert_eq!(traces.add_trace.lhs_values.len(), 1); + assert_eq!(traces.add_trace.lhs_index, vec![WitnessId(2)]); + assert_eq!(traces.add_trace.rhs_index, vec![WitnessId(0)]); + assert_eq!(traces.add_trace.result_index, vec![WitnessId(4)]); + } + #[test] fn test_extension_field_support() { type ExtField = BinomialExtensionField; diff --git a/circuit/src/utils.rs b/circuit/src/utils.rs index bb7a080..0d85d18 100644 --- a/circuit/src/utils.rs +++ b/circuit/src/utils.rs @@ -1,10 +1,13 @@ use alloc::vec; use alloc::vec::Vec; +use core::marker::PhantomData; -use p3_field::Field; +use p3_field::{ExtensionField, Field, PrimeField64}; use p3_uni_stark::{Entry, SymbolicExpression}; +use p3_util::log2_ceil_u64; -use crate::{CircuitBuilder, ExprId}; +use crate::op::WitnessHintFiller; +use crate::{CircuitBuilder, CircuitError, ExprId}; /// Identifiers for special row selector flags in the circuit. #[derive(Clone, Copy, Debug)] @@ -126,26 +129,69 @@ pub fn reconstruct_index_from_bits( acc } +#[derive(Debug, Clone)] +/// Given a field element as input, decompose it into its little-endian bits and +/// fill witness hints with the binary decomposition. +/// +/// For a given input `input`, fills `n_bits` witness hints with `b_i` +/// such that that: +/// input = Σ b_i · 2^i +struct BinaryDecompositionHint { + inputs: Vec, + n_bits: usize, + _phantom: PhantomData, +} + +impl BinaryDecompositionHint { + pub fn new(input: ExprId, n_bits: usize) -> Result { + if n_bits > 64 { + return Err(CircuitError::UnconstrainedOpInputLengthMismatch { + expected: 64, + got: n_bits, + }); + } + Ok(Self { + inputs: vec![input], + n_bits, + _phantom: PhantomData, + }) + } +} + +impl> WitnessHintFiller for BinaryDecompositionHint { + fn inputs(&self) -> &[ExprId] { + &self.inputs + } + + fn n_outputs(&self) -> usize { + self.n_bits + } + + fn compute_outputs(&self, inputs_val: Vec) -> Result, CircuitError> { + let val: u64 = inputs_val[0].as_basis_coefficients_slice()[0].as_canonical_u64(); + let bits = (0..self.n_bits) + .map(|i| F::from_bool(val >> i & 1 == 1)) + .collect(); + debug_assert!(self.n_bits as u64 >= log2_ceil_u64(val)); + Ok(bits) + } +} + /// Decompose a field element into its little-endian bits. /// /// For a given target `x`, this function creates `N_BITS` new boolean targets `b_i` /// and adds constraints to enforce that: /// x = Σ b_i · 2^i -pub fn decompose_to_bits( +pub fn decompose_to_bits, BF: PrimeField64>( builder: &mut CircuitBuilder, x: ExprId, n_bits: usize, -) -> Vec { +) -> Result, CircuitError> { builder.push_scope("decompose_to_bits"); - let mut bits = Vec::with_capacity(n_bits); - // Create bit witness variables - for _ in 0..n_bits { - let bit = builder.add_public_input(); // TODO: Should be witness - builder.assert_bool(bit); - bits.push(bit); - } + let binary_decomposition_hint = BinaryDecompositionHint::new(x, n_bits)?; + let bits = builder.alloc_witness_hints(binary_decomposition_hint, "decompose_to_bits"); // Constrain that the bits reconstruct to the original element let reconstructed = reconstruct_index_from_bits(builder, &bits); @@ -153,7 +199,7 @@ pub fn decompose_to_bits( builder.pop_scope(); - bits + Ok(bits) } /// Helper to pad trace values to power-of-two height with zeros @@ -385,28 +431,18 @@ mod tests { let value = builder.add_const(BabyBear::from_u64(6)); // Binary: 110 // Decompose into 3 bits - this creates its own public inputs for the bits - let bits = decompose_to_bits::(&mut builder, value, 3); + let bits = decompose_to_bits::(&mut builder, value, 3).unwrap(); // Build and run the circuit let circuit = builder.build().expect("Failed to build circuit"); - let mut runner = circuit.runner(); - - // Set public inputs: expected bit decomposition of 6 (binary: 110) in little-endian - let public_inputs = vec![ - BabyBear::ZERO, // bit 0: 0 - BabyBear::ONE, // bit 1: 1 - BabyBear::ONE, // bit 2: 1 - ]; + let runner = circuit.runner(); - runner - .set_public_inputs(&public_inputs) - .expect("Failed to set public inputs"); let traces = runner.run().expect("Failed to run circuit"); // Verify the bits are correctly decomposed - 6 = [0,1,1] in little-endian - assert_eq!(traces.public_trace.values[0], BabyBear::ZERO); // bit 0 - assert_eq!(traces.public_trace.values[1], BabyBear::ONE); // bit 1 - assert_eq!(traces.public_trace.values[2], BabyBear::ONE); // bit 2 + assert_eq!(traces.witness_trace.values[3], BabyBear::ZERO); // bit 0 + assert_eq!(traces.witness_trace.values[4], BabyBear::ONE); // bit 1 + assert_eq!(traces.witness_trace.values[5], BabyBear::ONE); // bit 2 // Also verify that the returned bits have the expected length assert_eq!(bits.len(), 3); diff --git a/mmcs-air/src/air.rs b/mmcs-air/src/air.rs index f23cd70..b019925 100644 --- a/mmcs-air/src/air.rs +++ b/mmcs-air/src/air.rs @@ -640,9 +640,9 @@ mod test { type MyConfig = StarkConfig; let config = MyConfig::new(pcs, challenger); - let proof = prove(&config, &air, trace, &vec![]); + let proof = prove(&config, &air, trace, &[]); // Verify the proof. - verify(&config, &air, &proof, &vec![]) + verify(&config, &air, &proof, &[]) } } diff --git a/recursion/src/generation.rs b/recursion/src/generation.rs index 911a5b9..5850b76 100644 --- a/recursion/src/generation.rs +++ b/recursion/src/generation.rs @@ -5,7 +5,7 @@ use itertools::zip_eq; use p3_air::Air; use p3_challenger::{CanObserve, CanSample, CanSampleBits, FieldChallenger, GrindingChallenger}; use p3_commit::{BatchOpening, Mmcs, Pcs, PolynomialSpace}; -use p3_field::{Field, PrimeCharacteristicRing, PrimeField, TwoAdicField}; +use p3_field::{PrimeCharacteristicRing, PrimeField, TwoAdicField}; use p3_fri::{FriProof, TwoAdicFriPcs}; use p3_uni_stark::{ Domain, Proof, StarkGenericConfig, SymbolicAirBuilder, Val, VerifierConstraintFolder, @@ -245,29 +245,15 @@ where if params.len() != 2 { return Err(GenerationError::InvalidParameterCount(params.len(), 2)); } - // Observe PoW and sample bits. - let pow_bits = params[0]; + // Check PoW witness. challenger.observe(opening_proof.pow_witness); - // Sample a challenge and decompose it into bits. Add all bits to the challenges. + // Sample a challenge as H(transcript || pow_witness). The circuit later + // verifies that the challenge begins with the required number of leading zeros. let rand_f: Val = challenger.sample(); let rand_usize = rand_f.as_canonical_biguint().to_u64_digits()[0] as usize; - // Get the bits. The total number of bits is the number of bits in a base field element. - let total_num_bits = Val::::bits(); - let rand_bits = (0..total_num_bits) - .map(|i| SC::Challenge::from_usize((rand_usize >> i) & 1)) - .collect::>(); - // Push the sampled challenge, along with the bits. challenges.push(SC::Challenge::from_usize(rand_usize)); - challenges.extend(rand_bits); - - // Check that the first bits are all 0. - let pow_challenge = rand_usize & ((1 << pow_bits) - 1); - - if pow_challenge != 0 { - return Err(GenerationError::InvalidPowWitness); - } let log_height_max = params[1]; let log_global_max_height = opening_proof.commit_phase_commits.len() + log_height_max; diff --git a/recursion/src/pcs/fri/targets.rs b/recursion/src/pcs/fri/targets.rs index 6e142ab..6b81403 100644 --- a/recursion/src/pcs/fri/targets.rs +++ b/recursion/src/pcs/fri/targets.rs @@ -7,7 +7,9 @@ use p3_circuit::CircuitBuilder; use p3_circuit::utils::{RowSelectorsTargets, decompose_to_bits}; use p3_commit::{BatchOpening, ExtensionMmcs, Mmcs, PolynomialSpace}; use p3_field::coset::TwoAdicMultiplicativeCoset; -use p3_field::{ExtensionField, Field, PackedValue, PrimeCharacteristicRing, TwoAdicField}; +use p3_field::{ + ExtensionField, Field, PackedValue, PrimeCharacteristicRing, PrimeField64, TwoAdicField, +}; use p3_fri::{CommitPhaseProofStep, FriProof, QueryProof, TwoAdicFriPcs}; use p3_merkle_tree::MerkleTreeMmcs; use p3_symmetric::{CryptographicHasher, Hash, PseudoCompressionFunction}; @@ -429,6 +431,7 @@ where RecursiveFriMmcs::Commitment: ObservableCommitment, SC::Challenger: GrindingChallenger, SC::Challenger: CanObserve, + Val: PrimeField64, { type VerifierParams = FriVerifierParams; type RecursiveProof = RecursiveFriProof< @@ -531,7 +534,8 @@ where let index_bits_per_query: Vec> = query_indices .iter() .map(|&index_target| { - let all_bits = decompose_to_bits(circuit, index_target, MAX_QUERY_INDEX_BITS); + let all_bits = + decompose_to_bits(circuit, index_target, MAX_QUERY_INDEX_BITS).unwrap(); all_bits.into_iter().take(log_max_height).collect() }) .collect(); diff --git a/recursion/src/public_inputs.rs b/recursion/src/public_inputs.rs index 4bb8589..1ad1c05 100644 --- a/recursion/src/public_inputs.rs +++ b/recursion/src/public_inputs.rs @@ -180,7 +180,6 @@ where /// 1. AIR public values /// 2. Proof values /// 3. All challenges (alpha, zeta, zeta_next, betas, query indices) - /// 4. Query index bit decompositions (MAX_QUERY_INDEX_BITS per query) pub fn build(self) -> Vec { let mut builder = PublicInputBuilder::new(); @@ -188,24 +187,6 @@ where builder.add_proof_values(self.proof_values); builder.add_challenges(self.challenges.iter().copied()); - // The circuit calls decompose_to_bits on each query index, - // which creates MAX_QUERY_INDEX_BITS additional public inputs per query - let num_regular_challenges = self.challenges.len() - self.num_queries; - for &query_index in &self.challenges[num_regular_challenges..] { - let coeffs = query_index.as_basis_coefficients_slice(); - let index_usize = coeffs[0].as_canonical_u64() as usize; - - // Add bit decomposition (MAX_QUERY_INDEX_BITS public inputs) - for k in 0..MAX_QUERY_INDEX_BITS { - let bit: EF = if (index_usize >> k) & 1 == 1 { - EF::ONE - } else { - EF::ZERO - }; - builder.add_challenge(bit); - } - } - builder.build() } } diff --git a/recursion/src/traits/challenger.rs b/recursion/src/traits/challenger.rs index 8ceb7dd..cf195a1 100644 --- a/recursion/src/traits/challenger.rs +++ b/recursion/src/traits/challenger.rs @@ -4,7 +4,7 @@ use alloc::vec::Vec; use p3_circuit::CircuitBuilder; use p3_circuit::utils::decompose_to_bits; -use p3_field::Field; +use p3_field::{ExtensionField, Field, PrimeField64}; use crate::Target; @@ -82,16 +82,19 @@ pub trait RecursiveChallenger { /// /// # Returns /// Vector of the first `num_bits` bits as targets (each in {0, 1}) - fn sample_public_bits( + fn sample_public_bits( &mut self, circuit: &mut CircuitBuilder, total_num_bits: usize, num_bits: usize, - ) -> Vec { + ) -> Vec + where + F: ExtensionField, + { let x = self.sample(circuit); // Decompose to bits (adds public inputs for each bit and verifies they reconstruct x) - let bits = decompose_to_bits(circuit, x, total_num_bits); + let bits = decompose_to_bits(circuit, x, total_num_bits).unwrap(); bits[..num_bits].to_vec() } @@ -106,13 +109,15 @@ pub trait RecursiveChallenger { /// - `witness_bits`: Number of leading bits that must be zero /// - `witness`: The proof-of-work witness target /// - `total_num_bits`: Total number of bits to decompose - fn check_witness( + fn check_witness( &mut self, circuit: &mut CircuitBuilder, witness_bits: usize, witness: Target, total_num_bits: usize, - ) { + ) where + F: ExtensionField, + { self.observe(circuit, witness); let bits = self.sample_public_bits(circuit, total_num_bits, witness_bits);