Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
356 changes: 347 additions & 9 deletions circuit/src/builder/circuit_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,76 @@ where
self.expr_builder.add_mul(lhs, rhs, label)
}

/// Computes and returns `a * b + c`.
///
/// This is a common fused operation in cryptographic circuits.
///
/// # Arguments
/// * `a`, `b`, `c`: The expressions to operate on.
///
/// # Returns
/// A new `ExprId` representing the result of `a * b + c`.
///
/// # Cost
/// 1 multiplication and 1 addition constraint.
pub fn mul_add(&mut self, a: ExprId, b: ExprId, c: ExprId) -> ExprId {
let product = self.mul(a, b);
self.add(product, c)
}

/// Multiplies a slice of expressions together.
///
/// # Arguments
/// * `inputs`: A slice of `ExprId`s to multiply.
///
/// # Returns
/// A new `ExprId` representing the product of all inputs. Returns `1` if the slice is empty.
///
/// # Cost
/// `N-1` multiplication constraints, where `N` is the number of inputs.
pub fn mul_many(&mut self, inputs: &[ExprId]) -> ExprId {
// Handle edge cases for empty or single-element slices.
if inputs.is_empty() {
return self.add_const(F::ONE);
}
if inputs.len() == 1 {
return inputs[0];
}

// Efficiently multiply all elements using a fold.
inputs
.iter()
.skip(1)
.fold(inputs[0], |acc, &x| self.mul(acc, x))
}

/// Computes the inner product (dot product) of two slices of expressions.
///
/// Computes `∑ (a[i] * b[i])`.
///
/// # Arguments
/// * `a`: The first slice of `ExprId`s.
/// * `b`: The second slice of `ExprId`s.
///
/// # Panics
/// Panics if the input slices `a` and `b` have different lengths.
///
/// # Returns
/// A new `ExprId` representing the inner product.
///
/// # Cost
/// `N` multiplications and `N-1` additions, where `N` is the length of the slices.
pub fn inner_product(&mut self, a: &[ExprId], b: &[ExprId]) -> ExprId {
assert_eq!(a.len(), b.len(), "Input vectors must have the same length");

let zero = self.add_const(F::ZERO);

// Calculate the sum of element-wise products.
a.iter()
.zip(b.iter())
.fold(zero, |acc, (&x, &y)| self.mul_add(x, y, acc))
}

/// Divides two expressions.
///
/// Cost: 1 row in Mul table + 1 row in witness table (encoded as rhs * out = lhs).
Expand Down Expand Up @@ -355,7 +425,12 @@ where

#[cfg(test)]
mod tests {
use alloc::vec;
use alloc::vec::Vec;

use p3_baby_bear::BabyBear;
use p3_field::PrimeCharacteristicRing;
use proptest::prelude::*;

use super::*;

Expand Down Expand Up @@ -627,15 +702,6 @@ mod tests {
assert_eq!(circuit.witness_count, 2);
assert_eq!(circuit.primitive_ops.len(), 2);
}
}

#[cfg(test)]
mod proptests {
use p3_baby_bear::BabyBear;
use p3_field::PrimeCharacteristicRing;
use proptest::prelude::*;

use super::*;

// Strategy for generating valid field elements
fn field_element() -> impl Strategy<Value = BabyBear> {
Expand Down Expand Up @@ -773,4 +839,276 @@ mod proptests {
);
}
}

#[test]
fn test_mul_add() {
// Test case 1: Basic computation (3 * 4 + 5 = 17)
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let a = builder.add_const(BabyBear::from_u64(3));
let b = builder.add_const(BabyBear::from_u64(4));
let c = builder.add_const(BabyBear::from_u64(5));
let result = builder.mul_add(a, b, c);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::from_u64(17)
);
}

// Test case 2: With zero product (0 * 7 + 9 = 9)
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let zero = builder.add_const(BabyBear::ZERO);
let b = builder.add_const(BabyBear::from_u64(7));
let c = builder.add_const(BabyBear::from_u64(9));
let result = builder.mul_add(zero, b, c);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::from_u64(9)
);
}
}

#[test]
fn test_mul_many() {
// Test case 1: Empty slice returns 1 (multiplicative identity)
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let result = builder.mul_many(&[]);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::ONE
);
}

// Test case 2: Multiple elements [2, 3, 4, 5] = 120
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let vals: Vec<ExprId> = vec![2, 3, 4, 5]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();
let result = builder.mul_many(&vals);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::from_u64(120)
);
}

// Test case 3: With zero element [5, 0, 7] = 0
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let with_zero = vec![
builder.add_const(BabyBear::from_u64(5)),
builder.add_const(BabyBear::ZERO),
builder.add_const(BabyBear::from_u64(7)),
];
let result = builder.mul_many(&with_zero);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::ZERO
);
}
}

#[test]
fn test_inner_product() {
// Test case 1: Basic dot product [1,2,3] · [4,5,6] = 32
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let a: Vec<ExprId> = vec![1, 2, 3]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();
let b: Vec<ExprId> = vec![4, 5, 6]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();
let result = builder.inner_product(&a, &b);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::from_u64(32)
);
}

// Test case 2: Empty vectors [] · [] = 0
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let empty_a: Vec<ExprId> = vec![];
let empty_b: Vec<ExprId> = vec![];
let result = builder.inner_product(&empty_a, &empty_b);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::ZERO
);
}

// Test case 3: Zero vector [0,0,0] · [5,6,7] = 0
{
let mut builder = CircuitBuilder::<BabyBear>::new();
let zeros: Vec<ExprId> = (0..3).map(|_| builder.add_const(BabyBear::ZERO)).collect();
let vals: Vec<ExprId> = vec![5, 6, 7]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();
let result = builder.inner_product(&zeros, &vals);

let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

assert_eq!(
traces.witness_trace.values[result.0 as usize],
BabyBear::ZERO
);
}
}

#[test]
#[should_panic(expected = "Input vectors must have the same length")]
fn test_inner_product_mismatched_lengths() {
// Verify that inner_product panics with mismatched vector lengths
let mut builder = CircuitBuilder::<BabyBear>::new();

// Create vectors with different lengths: [1,2] vs [3,4,5]
let a: Vec<ExprId> = vec![1, 2]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();
let b: Vec<ExprId> = vec![3, 4, 5]
.into_iter()
.map(|v| builder.add_const(BabyBear::from_u64(v)))
.collect();

// Should panic: lengths don't match (2 != 3)
builder.inner_product(&a, &b);
}

proptest! {
#[test]
fn prop_mul_add_correctness(
a in field_element(),
b in field_element(),
c in field_element()
) {
// Build circuit with mul_add
let mut builder = CircuitBuilder::<BabyBear>::new();
let ca = builder.add_const(a);
let cb = builder.add_const(b);
let cc = builder.add_const(c);
let result = builder.mul_add(ca, cb, cc);

// Execute circuit
let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

// Compute expected value
let expected = a * b + c;

// Verify correctness
prop_assert_eq!(
traces.witness_trace.values[result.0 as usize],
expected
);
}

#[test]
fn prop_mul_many_correctness(
values in prop::collection::vec(field_element(), 0..8)
) {
// Build circuit with mul_many
let mut builder = CircuitBuilder::<BabyBear>::new();
let expr_ids: Vec<ExprId> = values
.iter()
.map(|&v| builder.add_const(v))
.collect();
let result = builder.mul_many(&expr_ids);

// Execute circuit
let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

// Compute expected product (empty → 1, otherwise fold multiply)
let expected = if values.is_empty() {
BabyBear::ONE
} else {
values.iter().fold(BabyBear::ONE, |acc, &x| acc * x)
};

// Verify correctness
prop_assert_eq!(
traces.witness_trace.values[result.0 as usize],
expected
);
}

#[test]
fn prop_inner_product_correctness(
values in prop::collection::vec((field_element(), field_element()), 0..8)
) {
// Extract equal-length vectors from paired values
let vec1: Vec<BabyBear> = values.iter().map(|(a, _)| *a).collect();
let vec2: Vec<BabyBear> = values.iter().map(|(_, b)| *b).collect();

// Build circuit with inner_product
let mut builder = CircuitBuilder::<BabyBear>::new();
let a: Vec<ExprId> = vec1.iter().map(|&v| builder.add_const(v)).collect();
let b: Vec<ExprId> = vec2.iter().map(|&v| builder.add_const(v)).collect();
let result = builder.inner_product(&a, &b);

// Execute circuit
let circuit = builder.build().unwrap();
let runner = circuit.runner();
let traces = runner.run().unwrap();

// Compute expected dot product: Σ(a_i * b_i)
let expected = vec1
.iter()
.zip(vec2.iter())
.fold(BabyBear::ZERO, |acc, (&x, &y)| acc + x * y);

// Verify correctness
prop_assert_eq!(
traces.witness_trace.values[result.0 as usize],
expected
);
}
}
}
Loading