Skip to content

Commit 728ce02

Browse files
authored
circuit prover: refactor and documentation for AddAir (#164)
* circuit prover: refactor and documentation for AddAir * fix comments
1 parent 9ec4482 commit 728ce02

File tree

1 file changed

+234
-84
lines changed

1 file changed

+234
-84
lines changed

circuit-prover/src/air/add_air.rs

Lines changed: 234 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,72 @@
1-
//! [`AddAir`] deals with addition and subtraction. In the case of subtraction, `a - b = c` is written in the table as `b + c = a`. \
2-
//! The chip handles both base field and extension field operations, as it is parametrized by the extension degree `D`.
3-
//! The runtime parameter `lanes` also controls the number of operations carried out in a row.
1+
//! [`AddAir`] defines the AIR for proving addition and subtraction over both base and extension fields.
42
//!
5-
//! # Columns
3+
//! Conceptually, each row of the trace encodes one or more addition constraints of the form
64
//!
7-
//! The AIR has `3 * D + 3` columns for each operation:
5+
//! lhs + rhs = result
86
//!
9-
//! - `D` columns for the left operand,
10-
//! - 1 column for `index_left`: the index of the left operand in the witness bus,
11-
//! - `D` columns for the right operand,
12-
//! - 1 column for `index_right`: the index of the right operand in the witness bus,
13-
//! - `D` columns for the output,
14-
//! - 1 column for `index_output`: the index of the output in the witness bus.
7+
//! When the circuit wants to prove a subtraction, it is expressed as an addition by rewriting
8+
//!
9+
//! a - b = c
10+
//!
11+
//! as
12+
//!
13+
//! b + c = a
14+
//!
15+
//! so that subtraction is handled uniformly as an addition gate in the AIR.
16+
//!
17+
//! The AIR is generic over an extension degree `D`. Each operand and result is treated as
18+
//! an element of an extension field of degree `D` over the base field. Internally, this is
19+
//! represented as `D` base-field coordinates (basis coefficients), and the addition is
20+
//! checked component-wise. The runtime parameter `lanes` controls how many independent
21+
//! additions are packed side-by-side in a single row of the trace.
22+
//!
23+
//! # Column layout
24+
//!
25+
//! For each logical operation (lane) we allocate `3 * D + 3` base-field columns. These are
26+
//! grouped as:
27+
//!
28+
//! - `D` columns for the left operand (lhs) basis coefficients,
29+
//! - `1` column for `index_left`: the witness-bus index of the lhs operand,
30+
//! - `D` columns for the right operand (rhs) basis coefficients,
31+
//! - `1` column for `index_right`: the witness-bus index of the rhs operand,
32+
//! - `D` columns for the result operand basis coefficients,
33+
//! - `1` column for `index_output`: the witness-bus index of the result.
34+
//!
35+
//! In other words, for a single lane the layout is:
36+
//!
37+
//! [lhs[0..D), lhs_index, rhs[0..D), rhs_index, result[0..D), result_index]
38+
//!
39+
//! A single row can pack several of these lanes side-by-side, so the full row layout is
40+
//! this pattern repeated `lanes` times.
1541
//!
1642
//! # Constraints
1743
//!
18-
//! - for each triple `(left, right, output)`: `left[i] + right[i] - output[i]`, for `i` in `0..D`.
44+
//! Let `left[i]`, `right[i]`, and `output[i]` denote the `i`-th basis coordinate of the
45+
//! left, right and result extension field elements respectively. For each operation and
46+
//! each coordinate `i` in `0..D`, the AIR enforces the linear constraint
47+
//!
48+
//! \begin{equation}
49+
//! left[i] + right[i] - output[i] = 0.
50+
//! \end{equation}
51+
//!
52+
//! Since extension addition is coordinate-wise, these constraints are sufficient to show
53+
//! that the full extension elements satisfy
54+
//!
55+
//! \begin{equation}
56+
//! left + right = output.
57+
//! \end{equation}
58+
//!
59+
//! # Global interactions
1960
//!
20-
//! # Global Interactions
61+
//! Each operation (lane) has three interactions with the global witness bus:
2162
//!
22-
//! There are three interactions per operation with the witness bus:
23-
//! - send `(index_left, left)`
63+
//! - send `(index_left, left)`
2464
//! - send `(index_right, right)`
25-
//! - send `(index_output, output)`
26-
27-
#![allow(clippy::needless_range_loop)]
28-
use alloc::vec::Vec;
65+
//! - send `(index_output, result)`
66+
//!
67+
//! The AIR defined here focuses on the algebraic relation between the operands. The
68+
//! correctness of the indices with respect to the global witness bus is enforced by the
69+
//! bus interaction logic elsewhere in the system.
2970
3071
use p3_air::{Air, AirBuilder, BaseAir};
3172
use p3_circuit::tables::AddTrace;
@@ -34,22 +75,34 @@ use p3_field::{BasedVectorSpace, Field, PrimeCharacteristicRing};
3475
use p3_matrix::Matrix;
3576
use p3_matrix::dense::RowMajorMatrix;
3677

37-
/// AIR for proving addition operations: lhs + rhs = result.
38-
/// Generic over extension degree `D` (component-wise addition) and a runtime lane count
39-
/// that controls how many additions are packed side-by-side in a single row.
78+
/// AIR for proving addition gates of the form `lhs + rhs = result`.
79+
///
80+
/// The type is generic over:
81+
///
82+
/// - `F`: the base field,
83+
/// - `D`: the degree of the extension field; each operand is represented by `D` coordinates.
84+
///
85+
/// At runtime, a `lanes` parameter specifies how many addition gates are packed into each
86+
/// trace row.
4087
#[derive(Debug, Clone)]
4188
pub struct AddAir<F, const D: usize = 1> {
42-
/// Number of logical addition operations in the trace.
89+
/// Total number of logical addition operations (gates) in the trace.
4390
pub num_ops: usize,
44-
/// Number of lanes (operations) packed per row.
91+
/// Number of independent addition gates packed per trace row.
92+
///
93+
/// The last row is padded if the number of operations is not a multiple of this value.
4594
pub lanes: usize,
95+
/// Marker tying this AIR to its base field.
4696
_phantom: core::marker::PhantomData<F>,
4797
}
4898

4999
impl<F: Field + PrimeCharacteristicRing, const D: usize> AddAir<F, D> {
50-
/// Number of base-field columns contributed by a single lane.
51-
pub const LANE_WIDTH: usize = 3 * D + 3;
52-
100+
/// Construct a new `AddAir` instance.
101+
///
102+
/// - `num_ops`: total number of addition operations to be proven,
103+
/// - `lanes`: how many operations are packed side-by-side in each row.
104+
///
105+
/// Panics if `lanes == 0` because we always need at least one lane per row.
53106
pub const fn new(num_ops: usize, lanes: usize) -> Self {
54107
assert!(lanes > 0, "lane count must be non-zero");
55108
Self {
@@ -59,72 +112,143 @@ impl<F: Field + PrimeCharacteristicRing, const D: usize> AddAir<F, D> {
59112
}
60113
}
61114

115+
/// Number of base-field columns occupied by a single lane.
116+
///
117+
/// Each lane stores:
118+
/// - `3 * D` coordinates (for `lhs`, `rhs`, and `result`),
119+
/// - `3` indices (one for each operand).
120+
///
121+
/// The total width of a single row is `3 * D + 3`
62122
pub const fn lane_width() -> usize {
63-
Self::LANE_WIDTH
123+
3 * D + 3
64124
}
65125

126+
/// Total number of columns in the main trace for this AIR instance.
66127
pub const fn total_width(&self) -> usize {
67128
self.lanes * Self::lane_width()
68129
}
69130

70-
/// Convert `AddTrace` to a row-major matrix, packing `lanes` additions per row.
71-
/// Resulting layout per row:
72-
/// `[lhs[D], lhs_idx, rhs[D], rhs_idx, result[D], result_idx]` repeated `lanes` times.
131+
/// Convert an `AddTrace` into a `RowMajorMatrix` suitable for the STARK prover.
132+
///
133+
/// This function is responsible for:
134+
///
135+
/// 1. Taking the logical operations from the `AddTrace`:
136+
/// - `lhs_values`, `rhs_values`, `result_values` (extension elements),
137+
/// - `lhs_index`, `rhs_index`, `result_index` (witness-bus indices),
138+
/// 2. Decomposing each extension element into its `D` basis coordinates,
139+
/// 3. Packing `lanes` operations side-by-side in each row,
140+
/// 4. Padding the trace to have a power-of-two number of rows for FFT-friendly
141+
/// execution by the STARK prover.
142+
///
143+
/// The resulting matrix has:
144+
///
145+
/// - width `= lanes * LANE_WIDTH`,
146+
/// - height equal to the number of rows after packing and padding.
147+
///
148+
/// The layout within a row is:
149+
///
150+
/// [lhs[D], lhs_idx, rhs[D], rhs_idx, result[D], result_idx] repeated `lanes` times.
73151
pub fn trace_to_matrix<ExtF: BasedVectorSpace<F>>(
74152
trace: &AddTrace<ExtF>,
75153
lanes: usize,
76154
) -> RowMajorMatrix<F> {
155+
// Lanes must be strictly positive.
156+
//
157+
// Zero lanes would make it impossible to construct a row.
77158
assert!(lanes > 0, "lane count must be non-zero");
78159

160+
// Per-lane width in base-field columns.
79161
let lane_width = Self::lane_width();
162+
// Total width of each row once all lanes are packed.
80163
let width = lane_width * lanes;
164+
// Number of logical operations we need to pack into the trace.
81165
let op_count = trace.lhs_values.len();
166+
// Number of rows needed to hold `op_count` operations when each row carries `lanes` of them.
82167
let row_count = op_count.div_ceil(lanes);
83168

84-
let mut values = Vec::with_capacity(width * row_count.max(1));
85-
86-
for row in 0..row_count {
87-
for lane in 0..lanes {
88-
let op_idx = row * lanes + lane;
89-
if op_idx < op_count {
90-
// LHS limbs + index
91-
let lhs_coeffs = trace.lhs_values[op_idx].as_basis_coefficients_slice();
92-
assert_eq!(
93-
lhs_coeffs.len(),
94-
D,
95-
"Extension field degree mismatch for lhs",
96-
);
97-
values.extend_from_slice(lhs_coeffs);
98-
values.push(F::from_u64(trace.lhs_index[op_idx].0 as u64));
99-
100-
// RHS limbs + index
101-
let rhs_coeffs = trace.rhs_values[op_idx].as_basis_coefficients_slice();
102-
assert_eq!(
103-
rhs_coeffs.len(),
104-
D,
105-
"Extension field degree mismatch for rhs",
106-
);
107-
values.extend_from_slice(rhs_coeffs);
108-
values.push(F::from_u64(trace.rhs_index[op_idx].0 as u64));
109-
110-
// Result limbs + index
111-
let result_coeffs = trace.result_values[op_idx].as_basis_coefficients_slice();
112-
assert_eq!(
113-
result_coeffs.len(),
114-
D,
115-
"Extension field degree mismatch for result",
116-
);
117-
values.extend_from_slice(result_coeffs);
118-
values.push(F::from_u64(trace.result_index[op_idx].0 as u64));
119-
} else {
120-
// Filler lane: append zeros for unused slot to keep the row width uniform.
121-
values.resize(values.len() + lane_width, F::ZERO);
122-
}
123-
}
169+
// Pre-allocate the entire trace as a flat vector in row-major order.
170+
//
171+
// We start with `row_count` rows, each of width `width`, and fill it with zeros.
172+
// This automatically provides a clean padding for any unused lanes in the final row.
173+
let mut values = F::zero_vec(width * row_count.max(1));
174+
175+
// Iterate over all operations in lockstep across the trace arrays.
176+
for (op_idx, (((((lhs_val, lhs_idx), rhs_val), rhs_idx), res_val), res_idx)) in trace
177+
.lhs_values
178+
.iter()
179+
.zip(trace.lhs_index.iter())
180+
.zip(trace.rhs_values.iter())
181+
.zip(trace.rhs_index.iter())
182+
.zip(trace.result_values.iter())
183+
.zip(trace.result_index.iter())
184+
.enumerate()
185+
{
186+
// Determine the target row index.
187+
let row = op_idx / lanes;
188+
// Determine which lane within that row this operation occupies.
189+
let lane = op_idx % lanes;
190+
191+
// Compute the starting column index (cursor) for this lane within the flat vector.
192+
//
193+
// Row-major layout means:
194+
// row_offset = row * width,
195+
// lane_offset = lane * lane_width.
196+
let mut cursor = (row * width) + (lane * lane_width);
197+
198+
// Write LHS coordinates and LHS witness index.
199+
//
200+
// Extract the basis coefficients of the lhs extension element.
201+
let lhs_coeffs = lhs_val.as_basis_coefficients_slice();
202+
// Sanity check: the extension degree must match the generic parameter `D`.
203+
assert_eq!(
204+
lhs_coeffs.len(),
205+
D,
206+
"Extension field degree mismatch for lhs"
207+
);
208+
// Copy the `D` lhs coordinates into the trace row.
209+
values[cursor..cursor + D].copy_from_slice(lhs_coeffs);
210+
cursor += D;
211+
// Store the lhs witness-bus index as a base-field element.
212+
values[cursor] = F::from_u32(lhs_idx.0);
213+
cursor += 1;
214+
215+
// Write RHS coordinates and RHS witness index.
216+
//
217+
// Extract the basis coefficients of the rhs extension element.
218+
let rhs_coeffs = rhs_val.as_basis_coefficients_slice();
219+
// Sanity check: the extension degree must match the generic parameter `D`.
220+
assert_eq!(
221+
rhs_coeffs.len(),
222+
D,
223+
"Extension field degree mismatch for rhs"
224+
);
225+
// Copy the `D` rhs coordinates into the trace row.
226+
values[cursor..cursor + D].copy_from_slice(rhs_coeffs);
227+
cursor += D;
228+
// Store the rhs witness-bus index as a base-field element.
229+
values[cursor] = F::from_u32(rhs_idx.0);
230+
cursor += 1;
231+
232+
// Write result coordinates and result witness index.
233+
//
234+
// Extract the basis coefficients of the result extension element.
235+
let res_coeffs = res_val.as_basis_coefficients_slice();
236+
debug_assert_eq!(
237+
res_coeffs.len(),
238+
D,
239+
"Extension field degree mismatch for result"
240+
);
241+
// Copy the `D` result coordinates into the trace row.
242+
values[cursor..cursor + D].copy_from_slice(res_coeffs);
243+
cursor += D;
244+
// Store the result witness-bus index as a base-field element.
245+
values[cursor] = F::from_u32(res_idx.0);
124246
}
125247

248+
// Pad the matrix to a power-of-two height.
126249
pad_to_power_of_two(&mut values, width, row_count);
127250

251+
// Build the row-major matrix with the computed width.
128252
RowMajorMatrix::new(values, width)
129253
}
130254
}
@@ -140,26 +264,51 @@ where
140264
AB::F: Field,
141265
{
142266
fn eval(&self, builder: &mut AB) {
267+
// Access the main trace view from the builder.
143268
let main = builder.main();
144269

270+
// Make sure that the matrix width matches what this AIR expects.
145271
debug_assert_eq!(main.width(), self.total_width(), "column width mismatch");
146272

273+
// Get the evaluation at evaluation point `zeta`
147274
let local = main.row_slice(0).expect("matrix must be non-empty");
148-
let local = &*local;
149275
let lane_width = Self::lane_width();
150276

151-
for lane in 0..self.lanes {
152-
let mut cursor = lane * lane_width;
153-
let lhs_slice = &local[cursor..cursor + D];
154-
cursor += D + 1; // Skip lhs index
155-
let rhs_slice = &local[cursor..cursor + D];
156-
cursor += D + 1; // Skip rhs index
157-
let result_slice = &local[cursor..cursor + D];
158-
159-
for i in 0..D {
160-
builder.assert_zero(
161-
lhs_slice[i].clone() + rhs_slice[i].clone() - result_slice[i].clone(),
162-
);
277+
// Iterate over the row in fixed-size chunks, each chunk describing one lane:
278+
//
279+
// [lhs[0..D), lhs_idx, rhs[0..D), rhs_idx, result[0..D), result_idx]
280+
for lane_data in local.chunks_exact(lane_width) {
281+
// First, split off the lhs block and its index:
282+
//
283+
// lhs_and_idx = [lhs[0..D), lhs_idx]
284+
// rest = [rhs[0..D), rhs_idx, result[0..D), result_idx]
285+
let (lhs_and_idx, rest) = lane_data.split_at(D + 1);
286+
// Next, split the remaining data into:
287+
//
288+
// rhs_and_idx = [rhs[0..D), rhs_idx]
289+
// result_and_idx = [result[0..D), result_idx]
290+
let (rhs_and_idx, result_and_idx) = rest.split_at(D + 1);
291+
292+
// Extract just the coordinate slices for the three operands.
293+
//
294+
// NOTE: Indices reside at position [D] in each `*_and_idx` slice.
295+
// They are not used in constraints, but are checked by the bus interaction logic.
296+
let lhs_slice = &lhs_and_idx[..D];
297+
let rhs_slice = &rhs_and_idx[..D];
298+
let result_slice = &result_and_idx[..D];
299+
300+
// Enforce coordinate-wise addition for each basis coordinate `i` in `0..D`.
301+
//
302+
// For each `i`, we add the constraint:
303+
//
304+
// lhs_slice[i] + rhs_slice[i] - result_slice[i] = 0.
305+
for ((lhs, rhs), result) in lhs_slice
306+
.iter()
307+
.zip(rhs_slice.iter())
308+
.zip(result_slice.iter())
309+
{
310+
// Push a single linear constraint into the builder.
311+
builder.assert_zero(lhs.clone() + rhs.clone() - result.clone());
163312
}
164313
}
165314
}
@@ -168,6 +317,7 @@ where
168317
#[cfg(test)]
169318
mod tests {
170319
use alloc::vec;
320+
use alloc::vec::Vec;
171321

172322
use p3_baby_bear::BabyBear as Val;
173323
use p3_circuit::WitnessId;

0 commit comments

Comments
 (0)