Skip to content

Commit f2c1da5

Browse files
committed
refactor a bit more
1 parent 1e5c9d9 commit f2c1da5

File tree

2 files changed

+171
-63
lines changed

2 files changed

+171
-63
lines changed

src/sumcheck/constraints.rs

Lines changed: 26 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -149,19 +149,22 @@ impl<FE: CodecFieldElement + LagrangePolynomialFieldElement> ProofConstraints<FE
149149

150150
let mut layer_claim = SymbolicExpression::new(layer_index);
151151

152-
// Known portion of layer's initial claim.
153-
layer_claim += claims[0] + alpha * claims[1];
152+
let mut claim_0 = SymbolicTerm::from_known(claims[0]);
153+
let mut claim_1 = SymbolicTerm::from_known(claims[1]) * alpha;
154154

155155
if layer_index > 0 {
156156
// For layers past the first, claims is computed from the previous layer's vl and
157157
// vr, so we need linear constraint terms for that symbolic manipulation.
158158
let (vl_witness, vr_witness, _) =
159159
witness_layout.wire_witness_indices(layer_index - 1);
160160

161-
layer_claim += SymbolicTerm::new(vl_witness);
162-
layer_claim += SymbolicTerm::new(vr_witness) * alpha;
161+
claim_0.with_witness(vl_witness);
162+
claim_1.with_witness(vr_witness);
163163
}
164164

165+
layer_claim += claim_0;
166+
layer_claim += claim_1;
167+
165168
for (round, polynomial_pair) in proof_layer.polynomials.iter().enumerate() {
166169
for (hand, polynomial) in polynomial_pair.iter().enumerate() {
167170
transcript.write_polynomial(polynomial)?;
@@ -179,34 +182,26 @@ impl<FE: CodecFieldElement + LagrangePolynomialFieldElement> ProofConstraints<FE
179182
//
180183
// Compute the current claim:
181184
//
182-
// claim = lag_0(challenge) * p0
183-
// + lag_1(challenge) * p1
184-
// + lag_2(challenge) * p2
185+
// claim = p0 * lag_0(challenge)
186+
// + p1 * lag_1(challenge)
187+
// + p2 * lag_2(challenge)
185188
//
186189
// Expanding p1 = prev_claim - p0 and rearranging:
187190
//
188-
// claim = lag_1(challenge) * prev_claim
189-
// + (lag_0(challenge) - lag_1(challenge)) * p0
190-
// + lag_2(challenge) * p2
191+
// claim = prev_claim * lag_1(challenge)
192+
// + p0 * (lag_0(challenge) - lag_1(challenge))
193+
// + p2 * lag_2(challenge)
191194

192195
// lag_1(challenge) * prev_claim
193196
layer_claim *= FE::lagrange_basis_polynomial_1(challenge[0]);
194197

195-
// (lag_0(challenge) - lag_1(challenge)) * p0, known part:
196-
layer_claim += polynomial.p0
197-
* (FE::lagrange_basis_polynomial_0(challenge[0])
198-
- FE::lagrange_basis_polynomial_1(challenge[0]));
199-
200-
// (lag_0(challenge) - lag_1(challenge)) * p0, symbolic part:
201-
layer_claim += SymbolicTerm::new(p0_witness)
198+
// p0 * (lag_0(challenge) - lag_1(challenge)):
199+
layer_claim += (SymbolicTerm::new(p0_witness) + polynomial.p0)
202200
* (FE::lagrange_basis_polynomial_0(challenge[0])
203201
- FE::lagrange_basis_polynomial_1(challenge[0]));
204202

205-
// lag_2(challenge) * p2, known part:
206-
layer_claim += polynomial.p2 * FE::lagrange_basis_polynomial_2(challenge[0]);
207-
208-
// lag_2(challenge) * p2, symbolic part:
209-
layer_claim += SymbolicTerm::new(p2_witness)
203+
// p2 * lag_2(challenge):
204+
layer_claim += (SymbolicTerm::new(p2_witness) + polynomial.p2)
210205
* FE::lagrange_basis_polynomial_2(challenge[0]);
211206

212207
bound_quad = bound_quad.bind(&challenge).transpose();
@@ -296,13 +291,15 @@ impl<FE: CodecFieldElement + LagrangePolynomialFieldElement> ProofConstraints<FE
296291
final_claim += SymbolicTerm::new(vr_witness) * -gamma;
297292

298293
// Linear constraint RHS
299-
final_claim += claims[0] + gamma * claims[1]
300-
- public_inputs
301-
.iter()
302-
.zip(eq2.iter())
303-
.fold(FE::ZERO, |sum, (public_input_i, eq2_i)| {
304-
sum + *public_input_i * eq2_i
305-
});
294+
final_claim += SymbolicTerm::from_known(
295+
claims[0] + gamma * claims[1]
296+
- public_inputs
297+
.iter()
298+
.zip(eq2.iter())
299+
.fold(FE::ZERO, |sum, (public_input_i, eq2_i)| {
300+
sum + *public_input_i * eq2_i
301+
}),
302+
);
306303

307304
constraints
308305
.linear_constraint_lhs

src/sumcheck/symbolic.rs

Lines changed: 145 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
use crate::{fields::FieldElement, sumcheck::constraints::LinearConstraintLhsTerm};
2-
use std::ops::{AddAssign, Mul, MulAssign, Sub};
2+
use std::ops::{Add, AddAssign, Mul, MulAssign};
33

44
/// A symbolic expression, used to accumulate symbolic terms that contribute to a circuit layer's
55
/// linear constraint.
66
#[derive(Debug, Clone, PartialEq, Eq)]
77
pub struct SymbolicExpression<FieldElement> {
88
constraint_number: usize,
9-
known: FieldElement,
10-
symbolic_terms: Vec<SymbolicTerm<FieldElement>>,
9+
terms: Vec<SymbolicTerm<FieldElement>>,
1110
}
1211

1312
impl<FE: FieldElement> SymbolicExpression<FE> {
@@ -16,76 +15,90 @@ impl<FE: FieldElement> SymbolicExpression<FE> {
1615
pub fn new(layer_index: usize) -> Self {
1716
Self {
1817
constraint_number: layer_index,
19-
known: FE::ZERO,
20-
symbolic_terms: Vec::new(),
18+
terms: Vec::new(),
2119
}
2220
}
2321

2422
/// The known portion of this expression. Its contribution to the linear constraint's right hand
2523
/// side.
2624
pub fn known(&self) -> FE {
27-
self.known
25+
self.terms
26+
.iter()
27+
.fold(FE::ZERO, |sum, term| sum + term.known)
2828
}
2929

3030
/// The linear constraint LHS terms for this expression.
3131
pub fn lhs_terms(&self) -> Vec<LinearConstraintLhsTerm<FE>> {
32-
self.symbolic_terms
32+
self.terms
3333
.iter()
34-
.map(|term| LinearConstraintLhsTerm {
35-
constraint_number: self.constraint_number,
36-
witness_index: term.witness_index,
37-
constant_factor: term.constant_factor,
34+
// Terms with no witness index do not contribute to LHS
35+
.filter_map(|term| {
36+
term.witness_index
37+
.map(|witness_index| LinearConstraintLhsTerm {
38+
constraint_number: self.constraint_number,
39+
witness_index,
40+
constant_factor: term.constant_factor,
41+
})
3842
})
3943
.collect()
4044
}
4145
}
4246

43-
impl<FE: FieldElement> AddAssign<FE> for SymbolicExpression<FE> {
44-
fn add_assign(&mut self, rhs: FE) {
45-
self.known += rhs;
46-
}
47-
}
48-
4947
impl<FE: FieldElement> AddAssign<SymbolicTerm<FE>> for SymbolicExpression<FE> {
5048
fn add_assign(&mut self, rhs: SymbolicTerm<FE>) {
51-
self.symbolic_terms.push(rhs);
52-
}
53-
}
54-
55-
impl<FE: FieldElement> Sub<FE> for SymbolicExpression<FE> {
56-
type Output = Self;
57-
58-
fn sub(self, rhs: FE) -> Self::Output {
59-
Self {
60-
known: self.known - rhs,
61-
..self
62-
}
49+
self.terms.push(rhs);
6350
}
6451
}
6552

6653
impl<FE: FieldElement> MulAssign<FE> for SymbolicExpression<FE> {
6754
fn mul_assign(&mut self, rhs: FE) {
68-
self.known *= rhs;
69-
self.symbolic_terms.iter_mut().for_each(|term| *term *= rhs);
55+
self.terms.iter_mut().for_each(|term| *term *= rhs);
7056
}
7157
}
7258

73-
/// A symbolic term in a symbolic expression.
59+
/// A symbolic term in a symbolic expression, consisting of
60+
/// `known + constant_factor * W[witness_index]`.
7461
#[derive(Debug, Clone, PartialEq, Eq)]
7562
pub struct SymbolicTerm<FieldElement> {
63+
/// The known portion of the expression.
64+
pub known: FieldElement,
7665
/// The index into the witness vector W. This is `j` in the specification.
77-
pub witness_index: usize,
66+
pub witness_index: Option<usize>,
7867
/// The constant factor `k`.
7968
pub constant_factor: FieldElement,
8069
}
8170

8271
impl<FE: FieldElement> SymbolicTerm<FE> {
8372
pub fn new(witness_index: usize) -> Self {
8473
Self {
85-
witness_index,
74+
known: FE::ZERO,
75+
witness_index: Some(witness_index),
76+
constant_factor: FE::ONE,
77+
}
78+
}
79+
80+
pub fn from_known(known: FE) -> Self {
81+
Self {
82+
known,
83+
witness_index: None,
8684
constant_factor: FE::ONE,
8785
}
8886
}
87+
88+
pub fn with_witness(&mut self, index: usize) {
89+
self.witness_index = Some(index);
90+
}
91+
}
92+
93+
impl<FE: FieldElement> Add<FE> for SymbolicTerm<FE> {
94+
type Output = Self;
95+
96+
fn add(self, rhs: FE) -> Self::Output {
97+
Self {
98+
known: self.known + rhs,
99+
..self
100+
}
101+
}
89102
}
90103

91104
impl<FE: FieldElement> Mul<FE> for SymbolicTerm<FE> {
@@ -94,13 +107,111 @@ impl<FE: FieldElement> Mul<FE> for SymbolicTerm<FE> {
94107
fn mul(self, rhs: FE) -> Self::Output {
95108
Self {
96109
constant_factor: self.constant_factor * rhs,
110+
known: self.known * rhs,
97111
..self
98112
}
99113
}
100114
}
101115

102116
impl<FE: FieldElement> MulAssign<FE> for SymbolicTerm<FE> {
103117
fn mul_assign(&mut self, rhs: FE) {
118+
self.known *= rhs;
104119
self.constant_factor *= rhs;
105120
}
106121
}
122+
123+
#[cfg(test)]
124+
mod tests {
125+
use super::*;
126+
use crate::fields::fieldp256::FieldP256;
127+
128+
#[test]
129+
fn term_ops() {
130+
let term = SymbolicTerm::new(1);
131+
132+
let term = term + FieldP256::from_u128(2);
133+
134+
assert_eq!(
135+
term,
136+
SymbolicTerm {
137+
known: FieldP256::from_u128(2),
138+
witness_index: Some(1),
139+
constant_factor: FieldP256::ONE
140+
}
141+
);
142+
143+
let mut term = term * FieldP256::from_u128(5);
144+
145+
assert_eq!(
146+
term,
147+
SymbolicTerm {
148+
known: FieldP256::from_u128(10),
149+
witness_index: Some(1),
150+
constant_factor: FieldP256::from_u128(5)
151+
}
152+
);
153+
154+
term *= FieldP256::from_u128(6);
155+
156+
assert_eq!(
157+
term,
158+
SymbolicTerm {
159+
known: FieldP256::from_u128(60),
160+
witness_index: Some(1),
161+
constant_factor: FieldP256::from_u128(30)
162+
}
163+
);
164+
}
165+
166+
#[test]
167+
fn expression_ops() {
168+
let mut expression = SymbolicExpression::new(11);
169+
assert_eq!(expression.lhs_terms(), vec![]);
170+
assert_eq!(expression.known(), FieldP256::ZERO);
171+
172+
// Term with both known and symbolic part
173+
expression += SymbolicTerm::new(22) + FieldP256::from_u128(11);
174+
175+
// Term with only symbolic part
176+
expression += SymbolicTerm::new(33);
177+
178+
// Term with only known part
179+
expression += SymbolicTerm::from_known(FieldP256::from_u128(3));
180+
181+
assert_eq!(
182+
expression.lhs_terms(),
183+
vec![
184+
LinearConstraintLhsTerm {
185+
constraint_number: 11,
186+
witness_index: 22,
187+
constant_factor: FieldP256::ONE,
188+
},
189+
LinearConstraintLhsTerm {
190+
constraint_number: 11,
191+
witness_index: 33,
192+
constant_factor: FieldP256::ONE,
193+
},
194+
]
195+
);
196+
assert_eq!(expression.known(), FieldP256::from_u128(14));
197+
198+
expression *= FieldP256::from_u128(6);
199+
200+
assert_eq!(
201+
expression.lhs_terms(),
202+
vec![
203+
LinearConstraintLhsTerm {
204+
constraint_number: 11,
205+
witness_index: 22,
206+
constant_factor: FieldP256::from_u128(6),
207+
},
208+
LinearConstraintLhsTerm {
209+
constraint_number: 11,
210+
witness_index: 33,
211+
constant_factor: FieldP256::from_u128(6),
212+
},
213+
]
214+
);
215+
assert_eq!(expression.known(), FieldP256::from_u128(14 * 6));
216+
}
217+
}

0 commit comments

Comments
 (0)