Skip to content

Commit fe5c5ee

Browse files
authored
circuit builder: more multiplication utils (#161)
* circuit builder: more multiplication utils * fix comment * refactor quotient with new helpers
1 parent a8760d0 commit fe5c5ee

File tree

2 files changed

+654
-80
lines changed

2 files changed

+654
-80
lines changed

circuit/src/builder/circuit_builder.rs

Lines changed: 344 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ use alloc::vec::Vec;
22
use core::hash::Hash;
33

44
use hashbrown::HashMap;
5+
use itertools::zip_eq;
56
use p3_field::{Field, PrimeCharacteristicRing};
67

78
use super::compiler::{ExpressionLowerer, NonPrimitiveLowerer, Optimizer};
@@ -188,6 +189,72 @@ where
188189
self.expr_builder.add_mul(lhs, rhs, label)
189190
}
190191

192+
/// Computes and returns `a * b + c`.
193+
///
194+
/// This is a common fused operation in cryptographic circuits.
195+
///
196+
/// # Arguments
197+
/// * `a`, `b`, `c`: The expressions to operate on.
198+
///
199+
/// # Returns
200+
/// A new `ExprId` representing the result of `a * b + c`.
201+
///
202+
/// # Cost
203+
/// 1 multiplication and 1 addition constraint.
204+
pub fn mul_add(&mut self, a: ExprId, b: ExprId, c: ExprId) -> ExprId {
205+
let product = self.mul(a, b);
206+
self.add(product, c)
207+
}
208+
209+
/// Multiplies a slice of expressions together.
210+
///
211+
/// # Arguments
212+
/// * `inputs`: A slice of `ExprId`s to multiply.
213+
///
214+
/// # Returns
215+
/// A new `ExprId` representing the product of all inputs. Returns `1` if the slice is empty.
216+
///
217+
/// # Cost
218+
/// `N-1` multiplication constraints, where `N` is the number of inputs.
219+
pub fn mul_many(&mut self, inputs: &[ExprId]) -> ExprId {
220+
// Handle edge cases for empty or single-element slices.
221+
if inputs.is_empty() {
222+
return self.add_const(F::ONE);
223+
}
224+
if inputs.len() == 1 {
225+
return inputs[0];
226+
}
227+
228+
// Efficiently multiply all elements using a fold.
229+
inputs
230+
.iter()
231+
.skip(1)
232+
.fold(inputs[0], |acc, &x| self.mul(acc, x))
233+
}
234+
235+
/// Computes the inner product (dot product) of two slices of expressions.
236+
///
237+
/// Computes `∑ (a[i] * b[i])`.
238+
///
239+
/// # Arguments
240+
/// * `a`: The first slice of `ExprId`s.
241+
/// * `b`: The second slice of `ExprId`s.
242+
///
243+
/// # Panics
244+
/// Panics if the input slices `a` and `b` have different lengths.
245+
///
246+
/// # Returns
247+
/// A new `ExprId` representing the inner product.
248+
///
249+
/// # Cost
250+
/// `N` multiplications and `N-1` additions, where `N` is the length of the slices.
251+
pub fn inner_product(&mut self, a: &[ExprId], b: &[ExprId]) -> ExprId {
252+
let zero = self.add_const(F::ZERO);
253+
254+
// Calculate the sum of element-wise products.
255+
zip_eq(a, b).fold(zero, |acc, (&x, &y)| self.mul_add(x, y, acc))
256+
}
257+
191258
/// Divides two expressions.
192259
///
193260
/// Cost: 1 row in Mul table + 1 row in witness table (encoded as rhs * out = lhs).
@@ -355,7 +422,12 @@ where
355422

356423
#[cfg(test)]
357424
mod tests {
425+
use alloc::vec;
426+
use alloc::vec::Vec;
427+
358428
use p3_baby_bear::BabyBear;
429+
use p3_field::PrimeCharacteristicRing;
430+
use proptest::prelude::*;
359431

360432
use super::*;
361433

@@ -627,15 +699,6 @@ mod tests {
627699
assert_eq!(circuit.witness_count, 2);
628700
assert_eq!(circuit.primitive_ops.len(), 2);
629701
}
630-
}
631-
632-
#[cfg(test)]
633-
mod proptests {
634-
use p3_baby_bear::BabyBear;
635-
use p3_field::PrimeCharacteristicRing;
636-
use proptest::prelude::*;
637-
638-
use super::*;
639702

640703
// Strategy for generating valid field elements
641704
fn field_element() -> impl Strategy<Value = BabyBear> {
@@ -773,4 +836,276 @@ mod proptests {
773836
);
774837
}
775838
}
839+
840+
#[test]
841+
fn test_mul_add() {
842+
// Test case 1: Basic computation (3 * 4 + 5 = 17)
843+
{
844+
let mut builder = CircuitBuilder::<BabyBear>::new();
845+
let a = builder.add_const(BabyBear::from_u64(3));
846+
let b = builder.add_const(BabyBear::from_u64(4));
847+
let c = builder.add_const(BabyBear::from_u64(5));
848+
let result = builder.mul_add(a, b, c);
849+
850+
let circuit = builder.build().unwrap();
851+
let runner = circuit.runner();
852+
let traces = runner.run().unwrap();
853+
854+
assert_eq!(
855+
traces.witness_trace.values[result.0 as usize],
856+
BabyBear::from_u64(17)
857+
);
858+
}
859+
860+
// Test case 2: With zero product (0 * 7 + 9 = 9)
861+
{
862+
let mut builder = CircuitBuilder::<BabyBear>::new();
863+
let zero = builder.add_const(BabyBear::ZERO);
864+
let b = builder.add_const(BabyBear::from_u64(7));
865+
let c = builder.add_const(BabyBear::from_u64(9));
866+
let result = builder.mul_add(zero, b, c);
867+
868+
let circuit = builder.build().unwrap();
869+
let runner = circuit.runner();
870+
let traces = runner.run().unwrap();
871+
872+
assert_eq!(
873+
traces.witness_trace.values[result.0 as usize],
874+
BabyBear::from_u64(9)
875+
);
876+
}
877+
}
878+
879+
#[test]
880+
fn test_mul_many() {
881+
// Test case 1: Empty slice returns 1 (multiplicative identity)
882+
{
883+
let mut builder = CircuitBuilder::<BabyBear>::new();
884+
let result = builder.mul_many(&[]);
885+
886+
let circuit = builder.build().unwrap();
887+
let runner = circuit.runner();
888+
let traces = runner.run().unwrap();
889+
890+
assert_eq!(
891+
traces.witness_trace.values[result.0 as usize],
892+
BabyBear::ONE
893+
);
894+
}
895+
896+
// Test case 2: Multiple elements [2, 3, 4, 5] = 120
897+
{
898+
let mut builder = CircuitBuilder::<BabyBear>::new();
899+
let vals: Vec<ExprId> = vec![2, 3, 4, 5]
900+
.into_iter()
901+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
902+
.collect();
903+
let result = builder.mul_many(&vals);
904+
905+
let circuit = builder.build().unwrap();
906+
let runner = circuit.runner();
907+
let traces = runner.run().unwrap();
908+
909+
assert_eq!(
910+
traces.witness_trace.values[result.0 as usize],
911+
BabyBear::from_u64(120)
912+
);
913+
}
914+
915+
// Test case 3: With zero element [5, 0, 7] = 0
916+
{
917+
let mut builder = CircuitBuilder::<BabyBear>::new();
918+
let with_zero = vec![
919+
builder.add_const(BabyBear::from_u64(5)),
920+
builder.add_const(BabyBear::ZERO),
921+
builder.add_const(BabyBear::from_u64(7)),
922+
];
923+
let result = builder.mul_many(&with_zero);
924+
925+
let circuit = builder.build().unwrap();
926+
let runner = circuit.runner();
927+
let traces = runner.run().unwrap();
928+
929+
assert_eq!(
930+
traces.witness_trace.values[result.0 as usize],
931+
BabyBear::ZERO
932+
);
933+
}
934+
}
935+
936+
#[test]
937+
fn test_inner_product() {
938+
// Test case 1: Basic dot product [1,2,3] · [4,5,6] = 32
939+
{
940+
let mut builder = CircuitBuilder::<BabyBear>::new();
941+
let a: Vec<ExprId> = vec![1, 2, 3]
942+
.into_iter()
943+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
944+
.collect();
945+
let b: Vec<ExprId> = vec![4, 5, 6]
946+
.into_iter()
947+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
948+
.collect();
949+
let result = builder.inner_product(&a, &b);
950+
951+
let circuit = builder.build().unwrap();
952+
let runner = circuit.runner();
953+
let traces = runner.run().unwrap();
954+
955+
assert_eq!(
956+
traces.witness_trace.values[result.0 as usize],
957+
BabyBear::from_u64(32)
958+
);
959+
}
960+
961+
// Test case 2: Empty vectors [] · [] = 0
962+
{
963+
let mut builder = CircuitBuilder::<BabyBear>::new();
964+
let empty_a: Vec<ExprId> = vec![];
965+
let empty_b: Vec<ExprId> = vec![];
966+
let result = builder.inner_product(&empty_a, &empty_b);
967+
968+
let circuit = builder.build().unwrap();
969+
let runner = circuit.runner();
970+
let traces = runner.run().unwrap();
971+
972+
assert_eq!(
973+
traces.witness_trace.values[result.0 as usize],
974+
BabyBear::ZERO
975+
);
976+
}
977+
978+
// Test case 3: Zero vector [0,0,0] · [5,6,7] = 0
979+
{
980+
let mut builder = CircuitBuilder::<BabyBear>::new();
981+
let zeros: Vec<ExprId> = (0..3).map(|_| builder.add_const(BabyBear::ZERO)).collect();
982+
let vals: Vec<ExprId> = vec![5, 6, 7]
983+
.into_iter()
984+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
985+
.collect();
986+
let result = builder.inner_product(&zeros, &vals);
987+
988+
let circuit = builder.build().unwrap();
989+
let runner = circuit.runner();
990+
let traces = runner.run().unwrap();
991+
992+
assert_eq!(
993+
traces.witness_trace.values[result.0 as usize],
994+
BabyBear::ZERO
995+
);
996+
}
997+
}
998+
999+
#[test]
1000+
#[should_panic]
1001+
fn test_inner_product_mismatched_lengths() {
1002+
// Verify that inner_product panics with mismatched vector lengths
1003+
let mut builder = CircuitBuilder::<BabyBear>::new();
1004+
1005+
// Create vectors with different lengths: [1,2] vs [3,4,5]
1006+
let a: Vec<ExprId> = vec![1, 2]
1007+
.into_iter()
1008+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
1009+
.collect();
1010+
let b: Vec<ExprId> = vec![3, 4, 5]
1011+
.into_iter()
1012+
.map(|v| builder.add_const(BabyBear::from_u64(v)))
1013+
.collect();
1014+
1015+
// Should panic: lengths don't match (2 != 3)
1016+
builder.inner_product(&a, &b);
1017+
}
1018+
1019+
proptest! {
1020+
#[test]
1021+
fn prop_mul_add_correctness(
1022+
a in field_element(),
1023+
b in field_element(),
1024+
c in field_element()
1025+
) {
1026+
// Build circuit with mul_add
1027+
let mut builder = CircuitBuilder::<BabyBear>::new();
1028+
let ca = builder.add_const(a);
1029+
let cb = builder.add_const(b);
1030+
let cc = builder.add_const(c);
1031+
let result = builder.mul_add(ca, cb, cc);
1032+
1033+
// Execute circuit
1034+
let circuit = builder.build().unwrap();
1035+
let runner = circuit.runner();
1036+
let traces = runner.run().unwrap();
1037+
1038+
// Compute expected value
1039+
let expected = a * b + c;
1040+
1041+
// Verify correctness
1042+
prop_assert_eq!(
1043+
traces.witness_trace.values[result.0 as usize],
1044+
expected
1045+
);
1046+
}
1047+
1048+
#[test]
1049+
fn prop_mul_many_correctness(
1050+
values in prop::collection::vec(field_element(), 0..8)
1051+
) {
1052+
// Build circuit with mul_many
1053+
let mut builder = CircuitBuilder::<BabyBear>::new();
1054+
let expr_ids: Vec<ExprId> = values
1055+
.iter()
1056+
.map(|&v| builder.add_const(v))
1057+
.collect();
1058+
let result = builder.mul_many(&expr_ids);
1059+
1060+
// Execute circuit
1061+
let circuit = builder.build().unwrap();
1062+
let runner = circuit.runner();
1063+
let traces = runner.run().unwrap();
1064+
1065+
// Compute expected product (empty → 1, otherwise fold multiply)
1066+
let expected = if values.is_empty() {
1067+
BabyBear::ONE
1068+
} else {
1069+
values.iter().fold(BabyBear::ONE, |acc, &x| acc * x)
1070+
};
1071+
1072+
// Verify correctness
1073+
prop_assert_eq!(
1074+
traces.witness_trace.values[result.0 as usize],
1075+
expected
1076+
);
1077+
}
1078+
1079+
#[test]
1080+
fn prop_inner_product_correctness(
1081+
values in prop::collection::vec((field_element(), field_element()), 0..8)
1082+
) {
1083+
// Extract equal-length vectors from paired values
1084+
let vec1: Vec<BabyBear> = values.iter().map(|(a, _)| *a).collect();
1085+
let vec2: Vec<BabyBear> = values.iter().map(|(_, b)| *b).collect();
1086+
1087+
// Build circuit with inner_product
1088+
let mut builder = CircuitBuilder::<BabyBear>::new();
1089+
let a: Vec<ExprId> = vec1.iter().map(|&v| builder.add_const(v)).collect();
1090+
let b: Vec<ExprId> = vec2.iter().map(|&v| builder.add_const(v)).collect();
1091+
let result = builder.inner_product(&a, &b);
1092+
1093+
// Execute circuit
1094+
let circuit = builder.build().unwrap();
1095+
let runner = circuit.runner();
1096+
let traces = runner.run().unwrap();
1097+
1098+
// Compute expected dot product: Σ(a_i * b_i)
1099+
let expected = vec1
1100+
.iter()
1101+
.zip(vec2.iter())
1102+
.fold(BabyBear::ZERO, |acc, (&x, &y)| acc + x * y);
1103+
1104+
// Verify correctness
1105+
prop_assert_eq!(
1106+
traces.witness_trace.values[result.0 as usize],
1107+
expected
1108+
);
1109+
}
1110+
}
7761111
}

0 commit comments

Comments
 (0)