@@ -2,6 +2,7 @@ use alloc::vec::Vec;
22use core:: hash:: Hash ;
33
44use hashbrown:: HashMap ;
5+ use itertools:: zip_eq;
56use p3_field:: { Field , PrimeCharacteristicRing } ;
67
78use super :: compiler:: { ExpressionLowerer , NonPrimitiveLowerer , Optimizer } ;
@@ -188,6 +189,72 @@ where
188189 self . expr_builder . add_mul ( lhs, rhs, label)
189190 }
190191
192+ /// Computes and returns `a * b + c`.
193+ ///
194+ /// This is a common fused operation in cryptographic circuits.
195+ ///
196+ /// # Arguments
197+ /// * `a`, `b`, `c`: The expressions to operate on.
198+ ///
199+ /// # Returns
200+ /// A new `ExprId` representing the result of `a * b + c`.
201+ ///
202+ /// # Cost
203+ /// 1 multiplication and 1 addition constraint.
204+ pub fn mul_add ( & mut self , a : ExprId , b : ExprId , c : ExprId ) -> ExprId {
205+ let product = self . mul ( a, b) ;
206+ self . add ( product, c)
207+ }
208+
209+ /// Multiplies a slice of expressions together.
210+ ///
211+ /// # Arguments
212+ /// * `inputs`: A slice of `ExprId`s to multiply.
213+ ///
214+ /// # Returns
215+ /// A new `ExprId` representing the product of all inputs. Returns `1` if the slice is empty.
216+ ///
217+ /// # Cost
218+ /// `N-1` multiplication constraints, where `N` is the number of inputs.
219+ pub fn mul_many ( & mut self , inputs : & [ ExprId ] ) -> ExprId {
220+ // Handle edge cases for empty or single-element slices.
221+ if inputs. is_empty ( ) {
222+ return self . add_const ( F :: ONE ) ;
223+ }
224+ if inputs. len ( ) == 1 {
225+ return inputs[ 0 ] ;
226+ }
227+
228+ // Efficiently multiply all elements using a fold.
229+ inputs
230+ . iter ( )
231+ . skip ( 1 )
232+ . fold ( inputs[ 0 ] , |acc, & x| self . mul ( acc, x) )
233+ }
234+
235+ /// Computes the inner product (dot product) of two slices of expressions.
236+ ///
237+ /// Computes `∑ (a[i] * b[i])`.
238+ ///
239+ /// # Arguments
240+ /// * `a`: The first slice of `ExprId`s.
241+ /// * `b`: The second slice of `ExprId`s.
242+ ///
243+ /// # Panics
244+ /// Panics if the input slices `a` and `b` have different lengths.
245+ ///
246+ /// # Returns
247+ /// A new `ExprId` representing the inner product.
248+ ///
249+ /// # Cost
250+ /// `N` multiplications and `N-1` additions, where `N` is the length of the slices.
251+ pub fn inner_product ( & mut self , a : & [ ExprId ] , b : & [ ExprId ] ) -> ExprId {
252+ let zero = self . add_const ( F :: ZERO ) ;
253+
254+ // Calculate the sum of element-wise products.
255+ zip_eq ( a, b) . fold ( zero, |acc, ( & x, & y) | self . mul_add ( x, y, acc) )
256+ }
257+
191258 /// Divides two expressions.
192259 ///
193260 /// Cost: 1 row in Mul table + 1 row in witness table (encoded as rhs * out = lhs).
@@ -355,7 +422,12 @@ where
355422
356423#[ cfg( test) ]
357424mod tests {
425+ use alloc:: vec;
426+ use alloc:: vec:: Vec ;
427+
358428 use p3_baby_bear:: BabyBear ;
429+ use p3_field:: PrimeCharacteristicRing ;
430+ use proptest:: prelude:: * ;
359431
360432 use super :: * ;
361433
@@ -627,15 +699,6 @@ mod tests {
627699 assert_eq ! ( circuit. witness_count, 2 ) ;
628700 assert_eq ! ( circuit. primitive_ops. len( ) , 2 ) ;
629701 }
630- }
631-
632- #[ cfg( test) ]
633- mod proptests {
634- use p3_baby_bear:: BabyBear ;
635- use p3_field:: PrimeCharacteristicRing ;
636- use proptest:: prelude:: * ;
637-
638- use super :: * ;
639702
640703 // Strategy for generating valid field elements
641704 fn field_element ( ) -> impl Strategy < Value = BabyBear > {
@@ -773,4 +836,276 @@ mod proptests {
773836 ) ;
774837 }
775838 }
839+
840+ #[ test]
841+ fn test_mul_add ( ) {
842+ // Test case 1: Basic computation (3 * 4 + 5 = 17)
843+ {
844+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
845+ let a = builder. add_const ( BabyBear :: from_u64 ( 3 ) ) ;
846+ let b = builder. add_const ( BabyBear :: from_u64 ( 4 ) ) ;
847+ let c = builder. add_const ( BabyBear :: from_u64 ( 5 ) ) ;
848+ let result = builder. mul_add ( a, b, c) ;
849+
850+ let circuit = builder. build ( ) . unwrap ( ) ;
851+ let runner = circuit. runner ( ) ;
852+ let traces = runner. run ( ) . unwrap ( ) ;
853+
854+ assert_eq ! (
855+ traces. witness_trace. values[ result. 0 as usize ] ,
856+ BabyBear :: from_u64( 17 )
857+ ) ;
858+ }
859+
860+ // Test case 2: With zero product (0 * 7 + 9 = 9)
861+ {
862+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
863+ let zero = builder. add_const ( BabyBear :: ZERO ) ;
864+ let b = builder. add_const ( BabyBear :: from_u64 ( 7 ) ) ;
865+ let c = builder. add_const ( BabyBear :: from_u64 ( 9 ) ) ;
866+ let result = builder. mul_add ( zero, b, c) ;
867+
868+ let circuit = builder. build ( ) . unwrap ( ) ;
869+ let runner = circuit. runner ( ) ;
870+ let traces = runner. run ( ) . unwrap ( ) ;
871+
872+ assert_eq ! (
873+ traces. witness_trace. values[ result. 0 as usize ] ,
874+ BabyBear :: from_u64( 9 )
875+ ) ;
876+ }
877+ }
878+
879+ #[ test]
880+ fn test_mul_many ( ) {
881+ // Test case 1: Empty slice returns 1 (multiplicative identity)
882+ {
883+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
884+ let result = builder. mul_many ( & [ ] ) ;
885+
886+ let circuit = builder. build ( ) . unwrap ( ) ;
887+ let runner = circuit. runner ( ) ;
888+ let traces = runner. run ( ) . unwrap ( ) ;
889+
890+ assert_eq ! (
891+ traces. witness_trace. values[ result. 0 as usize ] ,
892+ BabyBear :: ONE
893+ ) ;
894+ }
895+
896+ // Test case 2: Multiple elements [2, 3, 4, 5] = 120
897+ {
898+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
899+ let vals: Vec < ExprId > = vec ! [ 2 , 3 , 4 , 5 ]
900+ . into_iter ( )
901+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
902+ . collect ( ) ;
903+ let result = builder. mul_many ( & vals) ;
904+
905+ let circuit = builder. build ( ) . unwrap ( ) ;
906+ let runner = circuit. runner ( ) ;
907+ let traces = runner. run ( ) . unwrap ( ) ;
908+
909+ assert_eq ! (
910+ traces. witness_trace. values[ result. 0 as usize ] ,
911+ BabyBear :: from_u64( 120 )
912+ ) ;
913+ }
914+
915+ // Test case 3: With zero element [5, 0, 7] = 0
916+ {
917+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
918+ let with_zero = vec ! [
919+ builder. add_const( BabyBear :: from_u64( 5 ) ) ,
920+ builder. add_const( BabyBear :: ZERO ) ,
921+ builder. add_const( BabyBear :: from_u64( 7 ) ) ,
922+ ] ;
923+ let result = builder. mul_many ( & with_zero) ;
924+
925+ let circuit = builder. build ( ) . unwrap ( ) ;
926+ let runner = circuit. runner ( ) ;
927+ let traces = runner. run ( ) . unwrap ( ) ;
928+
929+ assert_eq ! (
930+ traces. witness_trace. values[ result. 0 as usize ] ,
931+ BabyBear :: ZERO
932+ ) ;
933+ }
934+ }
935+
936+ #[ test]
937+ fn test_inner_product ( ) {
938+ // Test case 1: Basic dot product [1,2,3] · [4,5,6] = 32
939+ {
940+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
941+ let a: Vec < ExprId > = vec ! [ 1 , 2 , 3 ]
942+ . into_iter ( )
943+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
944+ . collect ( ) ;
945+ let b: Vec < ExprId > = vec ! [ 4 , 5 , 6 ]
946+ . into_iter ( )
947+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
948+ . collect ( ) ;
949+ let result = builder. inner_product ( & a, & b) ;
950+
951+ let circuit = builder. build ( ) . unwrap ( ) ;
952+ let runner = circuit. runner ( ) ;
953+ let traces = runner. run ( ) . unwrap ( ) ;
954+
955+ assert_eq ! (
956+ traces. witness_trace. values[ result. 0 as usize ] ,
957+ BabyBear :: from_u64( 32 )
958+ ) ;
959+ }
960+
961+ // Test case 2: Empty vectors [] · [] = 0
962+ {
963+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
964+ let empty_a: Vec < ExprId > = vec ! [ ] ;
965+ let empty_b: Vec < ExprId > = vec ! [ ] ;
966+ let result = builder. inner_product ( & empty_a, & empty_b) ;
967+
968+ let circuit = builder. build ( ) . unwrap ( ) ;
969+ let runner = circuit. runner ( ) ;
970+ let traces = runner. run ( ) . unwrap ( ) ;
971+
972+ assert_eq ! (
973+ traces. witness_trace. values[ result. 0 as usize ] ,
974+ BabyBear :: ZERO
975+ ) ;
976+ }
977+
978+ // Test case 3: Zero vector [0,0,0] · [5,6,7] = 0
979+ {
980+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
981+ let zeros: Vec < ExprId > = ( 0 ..3 ) . map ( |_| builder. add_const ( BabyBear :: ZERO ) ) . collect ( ) ;
982+ let vals: Vec < ExprId > = vec ! [ 5 , 6 , 7 ]
983+ . into_iter ( )
984+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
985+ . collect ( ) ;
986+ let result = builder. inner_product ( & zeros, & vals) ;
987+
988+ let circuit = builder. build ( ) . unwrap ( ) ;
989+ let runner = circuit. runner ( ) ;
990+ let traces = runner. run ( ) . unwrap ( ) ;
991+
992+ assert_eq ! (
993+ traces. witness_trace. values[ result. 0 as usize ] ,
994+ BabyBear :: ZERO
995+ ) ;
996+ }
997+ }
998+
999+ #[ test]
1000+ #[ should_panic]
1001+ fn test_inner_product_mismatched_lengths ( ) {
1002+ // Verify that inner_product panics with mismatched vector lengths
1003+ let mut builder = CircuitBuilder :: < BabyBear > :: new ( ) ;
1004+
1005+ // Create vectors with different lengths: [1,2] vs [3,4,5]
1006+ let a: Vec < ExprId > = vec ! [ 1 , 2 ]
1007+ . into_iter ( )
1008+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
1009+ . collect ( ) ;
1010+ let b: Vec < ExprId > = vec ! [ 3 , 4 , 5 ]
1011+ . into_iter ( )
1012+ . map ( |v| builder. add_const ( BabyBear :: from_u64 ( v) ) )
1013+ . collect ( ) ;
1014+
1015+ // Should panic: lengths don't match (2 != 3)
1016+ builder. inner_product ( & a, & b) ;
1017+ }
1018+
1019+ proptest ! {
1020+ #[ test]
1021+ fn prop_mul_add_correctness(
1022+ a in field_element( ) ,
1023+ b in field_element( ) ,
1024+ c in field_element( )
1025+ ) {
1026+ // Build circuit with mul_add
1027+ let mut builder = CircuitBuilder :: <BabyBear >:: new( ) ;
1028+ let ca = builder. add_const( a) ;
1029+ let cb = builder. add_const( b) ;
1030+ let cc = builder. add_const( c) ;
1031+ let result = builder. mul_add( ca, cb, cc) ;
1032+
1033+ // Execute circuit
1034+ let circuit = builder. build( ) . unwrap( ) ;
1035+ let runner = circuit. runner( ) ;
1036+ let traces = runner. run( ) . unwrap( ) ;
1037+
1038+ // Compute expected value
1039+ let expected = a * b + c;
1040+
1041+ // Verify correctness
1042+ prop_assert_eq!(
1043+ traces. witness_trace. values[ result. 0 as usize ] ,
1044+ expected
1045+ ) ;
1046+ }
1047+
1048+ #[ test]
1049+ fn prop_mul_many_correctness(
1050+ values in prop:: collection:: vec( field_element( ) , 0 ..8 )
1051+ ) {
1052+ // Build circuit with mul_many
1053+ let mut builder = CircuitBuilder :: <BabyBear >:: new( ) ;
1054+ let expr_ids: Vec <ExprId > = values
1055+ . iter( )
1056+ . map( |& v| builder. add_const( v) )
1057+ . collect( ) ;
1058+ let result = builder. mul_many( & expr_ids) ;
1059+
1060+ // Execute circuit
1061+ let circuit = builder. build( ) . unwrap( ) ;
1062+ let runner = circuit. runner( ) ;
1063+ let traces = runner. run( ) . unwrap( ) ;
1064+
1065+ // Compute expected product (empty → 1, otherwise fold multiply)
1066+ let expected = if values. is_empty( ) {
1067+ BabyBear :: ONE
1068+ } else {
1069+ values. iter( ) . fold( BabyBear :: ONE , |acc, & x| acc * x)
1070+ } ;
1071+
1072+ // Verify correctness
1073+ prop_assert_eq!(
1074+ traces. witness_trace. values[ result. 0 as usize ] ,
1075+ expected
1076+ ) ;
1077+ }
1078+
1079+ #[ test]
1080+ fn prop_inner_product_correctness(
1081+ values in prop:: collection:: vec( ( field_element( ) , field_element( ) ) , 0 ..8 )
1082+ ) {
1083+ // Extract equal-length vectors from paired values
1084+ let vec1: Vec <BabyBear > = values. iter( ) . map( |( a, _) | * a) . collect( ) ;
1085+ let vec2: Vec <BabyBear > = values. iter( ) . map( |( _, b) | * b) . collect( ) ;
1086+
1087+ // Build circuit with inner_product
1088+ let mut builder = CircuitBuilder :: <BabyBear >:: new( ) ;
1089+ let a: Vec <ExprId > = vec1. iter( ) . map( |& v| builder. add_const( v) ) . collect( ) ;
1090+ let b: Vec <ExprId > = vec2. iter( ) . map( |& v| builder. add_const( v) ) . collect( ) ;
1091+ let result = builder. inner_product( & a, & b) ;
1092+
1093+ // Execute circuit
1094+ let circuit = builder. build( ) . unwrap( ) ;
1095+ let runner = circuit. runner( ) ;
1096+ let traces = runner. run( ) . unwrap( ) ;
1097+
1098+ // Compute expected dot product: Σ(a_i * b_i)
1099+ let expected = vec1
1100+ . iter( )
1101+ . zip( vec2. iter( ) )
1102+ . fold( BabyBear :: ZERO , |acc, ( & x, & y) | acc + x * y) ;
1103+
1104+ // Verify correctness
1105+ prop_assert_eq!(
1106+ traces. witness_trace. values[ result. 0 as usize ] ,
1107+ expected
1108+ ) ;
1109+ }
1110+ }
7761111}
0 commit comments