Skip to content

Commit 2f94892

Browse files
committed
Address remaining reviews
1 parent f9ef0e8 commit 2f94892

File tree

4 files changed

+21
-50
lines changed

4 files changed

+21
-50
lines changed

circuit/src/builder/circuit_builder.rs

Lines changed: 7 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use super::{BuilderConfig, ExpressionBuilder};
1010
use crate::CircuitBuilderError;
1111
use crate::builder::public_input_tracker::PublicInputTracker;
1212
use crate::circuit::Circuit;
13-
use crate::op::{NonPrimitiveOpType, WitnessHintFiller};
13+
use crate::op::{DefaultHint, NonPrimitiveOpType, WitnessHintFiller};
1414
use crate::ops::MmcsVerifyConfig;
1515
use crate::types::{ExprId, NonPrimitiveOpId, WitnessAllocator, WitnessId};
1616

@@ -140,9 +140,8 @@ where
140140
count: usize,
141141
label: &'static str,
142142
) -> Vec<ExprId> {
143-
(0..count)
144-
.map(|_| self.expr_builder.add_witness_hint(label))
145-
.collect()
143+
self.expr_builder
144+
.add_witness_hints(DefaultHint { n_outputs: count }, label)
146145
}
147146

148147
/// Adds a constant to the circuit (deduplicated).
@@ -437,7 +436,6 @@ where
437436
#[cfg(test)]
438437
mod tests {
439438
use alloc::vec;
440-
use alloc::vec::Vec;
441439

442440
use p3_baby_bear::BabyBear;
443441
use p3_field::PrimeCharacteristicRing;
@@ -713,39 +711,12 @@ mod tests {
713711
assert_eq!(circuit.primitive_ops.len(), 2);
714712
}
715713

716-
#[derive(Debug, Clone)]
717-
struct ConstantHint<const C: usize> {
718-
inputs: Vec<ExprId>,
719-
}
720-
721-
impl<const C: usize> ConstantHint<C> {
722-
pub fn new(input: ExprId) -> Self {
723-
Self {
724-
inputs: vec![input],
725-
}
726-
}
727-
}
728-
729-
impl<F: Field, const C: usize> WitnessHintFiller<F> for ConstantHint<C> {
730-
fn inputs(&self) -> &[ExprId] {
731-
&self.inputs
732-
}
733-
734-
fn n_outputs(&self) -> usize {
735-
1
736-
}
737-
738-
fn compute_outputs(&self, _inputs_val: Vec<F>) -> Result<Vec<F>, crate::CircuitError> {
739-
Ok(vec![F::from_usize(C)])
740-
}
741-
}
742714
#[test]
743715
fn test_build_with_witness_hint() {
744716
let mut builder = CircuitBuilder::<BabyBear>::new();
745-
let a = builder.add_const(BabyBear::ZERO);
746-
let mock_filler = ConstantHint::<1>::new(a);
747-
let b = builder.alloc_witness_hints(mock_filler, "a");
748-
assert_eq!(b.len(), 1);
717+
let default_hint = DefaultHint { n_outputs: 1 };
718+
let a = builder.alloc_witness_hints(default_hint, "a");
719+
assert_eq!(a.len(), 1);
749720
let circuit = builder
750721
.build()
751722
.expect("Circuit with operations should build");
@@ -757,7 +728,7 @@ mod tests {
757728
crate::op::Op::Unconstrained {
758729
inputs, outputs, ..
759730
} => {
760-
assert_eq!(*inputs, vec![WitnessId(0)]);
731+
assert_eq!(*inputs, vec![]);
761732
assert_eq!(*outputs, vec![WitnessId(1)]);
762733
}
763734
_ => panic!("Expected Unconstrained at index 0"),

circuit/src/builder/compiler/expression_lowerer.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use crate::Op;
1010
use crate::builder::CircuitBuilderError;
1111
use crate::builder::compiler::get_witness_id;
1212
use crate::expr::{Expr, ExpressionGraph};
13-
use crate::op::{DefaultHint, WitnessHintFiller};
13+
use crate::op::WitnessHintFiller;
1414
use crate::types::{ExprId, WitnessAllocator, WitnessId};
1515

1616
/// Sparse disjoint-set "find" with path compression over a HashMap (iterative).
@@ -187,7 +187,7 @@ where
187187
}
188188
}
189189

190-
// Pass C: emit arithmetic ops in creation order; tie outputs to class slot if connected
190+
// Pass C: emit arithmetic and unconstrained ops in creation order; tie outputs to class slot if connected
191191
let mut hints_sequence = vec![];
192192
let mut fillers_iter = self.hints_with_fillers.iter().cloned();
193193
for (expr_idx, expr) in self.graph.nodes().iter().enumerate() {
@@ -200,7 +200,9 @@ where
200200
expr_to_widx.insert(expr_id, out_widx);
201201
hints_sequence.push(out_widx);
202202
if *last_hint {
203-
let filler = fillers_iter.next().unwrap_or(DefaultHint::boxed_default());
203+
let filler = fillers_iter.next().expect(
204+
"By construction, every sequence of witness must haver one filler",
205+
);
204206
let inputs = filler
205207
.inputs()
206208
.iter()

circuit/src/op.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -419,25 +419,27 @@ impl<F> Clone for Box<dyn WitnessHintFiller<F>> {
419419
}
420420

421421
#[derive(Debug, Clone, Default)]
422-
pub struct DefaultHint {}
422+
pub struct DefaultHint {
423+
pub n_outputs: usize,
424+
}
423425

424426
impl DefaultHint {
425-
pub fn boxed_default<F: Default>() -> Box<dyn WitnessHintFiller<F>> {
427+
pub fn boxed_default<F: Default + Clone>() -> Box<dyn WitnessHintFiller<F>> {
426428
Box::new(Self::default())
427429
}
428430
}
429431

430-
impl<F: Default> WitnessHintFiller<F> for DefaultHint {
432+
impl<F: Default + Clone> WitnessHintFiller<F> for DefaultHint {
431433
fn inputs(&self) -> &[ExprId] {
432434
&[]
433435
}
434436

435437
fn n_outputs(&self) -> usize {
436-
1
438+
self.n_outputs
437439
}
438440

439441
fn compute_outputs(&self, _inputs_val: Vec<F>) -> Result<Vec<F>, CircuitError> {
440-
Ok(vec![F::default()])
442+
Ok(vec![F::default(); self.n_outputs])
441443
}
442444
}
443445

circuit/src/utils.rs

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use core::marker::PhantomData;
44

55
use p3_field::{ExtensionField, Field, PrimeField64};
66
use p3_uni_stark::{Entry, SymbolicExpression};
7+
use p3_util::log2_ceil_u64;
78

89
use crate::op::WitnessHintFiller;
910
use crate::{CircuitBuilder, CircuitError, ExprId};
@@ -167,16 +168,11 @@ impl<BF: PrimeField64, F: ExtensionField<BF>> WitnessHintFiller<F> for BinaryDec
167168
}
168169

169170
fn compute_outputs(&self, inputs_val: Vec<F>) -> Result<Vec<F>, CircuitError> {
170-
if inputs_val.len() != 1 {
171-
return Err(CircuitError::UnconstrainedOpInputLengthMismatch {
172-
expected: 1,
173-
got: inputs_val.len(),
174-
});
175-
}
176171
let val: u64 = inputs_val[0].as_basis_coefficients_slice()[0].as_canonical_u64();
177172
let bits = (0..self.n_bits)
178173
.map(|i| F::from_bool(val >> i & 1 == 1))
179174
.collect();
175+
debug_assert!(self.n_bits as u64 >= log2_ceil_u64(val));
180176
Ok(bits)
181177
}
182178
}

0 commit comments

Comments
 (0)