diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index a666789..636f3ad 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -97,6 +97,7 @@ jobs: # cargo build --verbose --target ${{ env.target }} -p p3-fri-air # cargo build --verbose --target ${{ env.target }} -p p3-interpolation-air cargo build --verbose --target ${{ env.target }} -p p3-mmcs-air + cargo build --verbose --target ${{ env.target }} -p p3-poseidon2-circuit-air cargo build --verbose --target ${{ env.target }} -p p3-symmetric-air cargo build --verbose --target ${{ env.target }} -p p3-recursion diff --git a/Cargo.toml b/Cargo.toml index 3d9274b..fdebd8c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ members = [ "fri-air", "interpolation-air", "mmcs-air", + "poseidon2-circuit-air", "recursion", "symmetric-air", ] @@ -41,6 +42,8 @@ p3-koala-bear = { git = "https://github.com/Plonky3/Plonky3" } p3-matrix = { git = "https://github.com/Plonky3/Plonky3" } p3-maybe-rayon = { git = "https://github.com/Plonky3/Plonky3" } p3-merkle-tree = { git = "https://github.com/Plonky3/Plonky3" } +p3-poseidon2 = { git = "https://github.com/Plonky3/Plonky3" } +p3-poseidon2-air = { git = "https://github.com/Plonky3/Plonky3" } p3-symmetric = { git = "https://github.com/Plonky3/Plonky3" } p3-uni-stark = { git = "https://github.com/Plonky3/Plonky3" } p3-util = { git = "https://github.com/Plonky3/Plonky3" } @@ -72,6 +75,7 @@ p3-field-air = { path = "field-air", version = "0.1.0" } p3-fri-air = { path = "fri-air", version = "0.1.0" } p3-interpolation-air = { path = "interpolation-air", version = "0.1.0" } p3-mmcs-air = { path = "mmcs-air", version = "0.1.0" } +p3-poseidon2-circuit-air = { path = "poseidon2-circuit-air", version = "0.1.0" } p3-recursion = { path = "recursion", version = "0.1.0" } p3-symmetric-air = { path = "symmetric-air", version = "0.1.0" } diff --git a/circuit-prover/Cargo.toml b/circuit-prover/Cargo.toml index cc1c22c..c3987ee 100644 --- a/circuit-prover/Cargo.toml +++ b/circuit-prover/Cargo.toml @@ -28,6 +28,7 @@ p3-matrix.workspace = true p3-maybe-rayon.workspace = true p3-merkle-tree.workspace = true p3-mmcs-air.workspace = true +p3-poseidon2-air.workspace = true p3-symmetric.workspace = true p3-uni-stark.workspace = true rand.workspace = true diff --git a/circuit/src/tables/mod.rs b/circuit/src/tables/mod.rs index ae24782..30b77c9 100644 --- a/circuit/src/tables/mod.rs +++ b/circuit/src/tables/mod.rs @@ -4,6 +4,7 @@ mod add; mod constant; mod mmcs; mod mul; +mod poseidon2; mod public; mod runner; mod witness; @@ -12,6 +13,7 @@ pub use add::AddTrace; pub use constant::ConstTrace; pub use mmcs::{MmcsPathTrace, MmcsPrivateData, MmcsTrace}; pub use mul::MulTrace; +pub use poseidon2::{Poseidon2CircuitRow, Poseidon2CircuitTrace}; pub use public::PublicTrace; pub use runner::CircuitRunner; pub use witness::WitnessTrace; diff --git a/circuit/src/tables/poseidon2.rs b/circuit/src/tables/poseidon2.rs new file mode 100644 index 0000000..8c47e3b --- /dev/null +++ b/circuit/src/tables/poseidon2.rs @@ -0,0 +1,18 @@ +use alloc::vec::Vec; + +/// Poseidon2 operation table +pub struct Poseidon2CircuitRow { + /// Poseidon2 operation type + pub is_sponge: bool, + /// Reset flag + pub reset: bool, + /// Absorb flags + pub absorb_flags: Vec, + /// Inputs to the Poseidon2 permutation + pub input_values: Vec, + /// Input indices + pub input_indices: Vec, + /// Output indices + pub output_indices: Vec, +} +pub type Poseidon2CircuitTrace = Vec>; diff --git a/poseidon2-circuit-air/Cargo.toml b/poseidon2-circuit-air/Cargo.toml new file mode 100644 index 0000000..1a01944 --- /dev/null +++ b/poseidon2-circuit-air/Cargo.toml @@ -0,0 +1,44 @@ +[package] +name = "p3-poseidon2-circuit-air" +description = "An AIR implementation of Poseidon2 modified for the circuit builder." +version.workspace = true +edition.workspace = true +license.workspace = true +repository.workspace = true +homepage.workspace = true +keywords.workspace = true +categories.workspace = true + +[dependencies] +p3-air.workspace = true +p3-circuit.workspace = true +p3-field.workspace = true +p3-matrix.workspace = true +p3-maybe-rayon.workspace = true +p3-poseidon2.workspace = true +p3-poseidon2-air.workspace = true +p3-symmetric.workspace = true + +itertools.workspace = true +rand.workspace = true +tracing.workspace = true + +[target.'cfg(target_family = "unix")'.dev-dependencies] +tikv-jemallocator = "0.6" + +[dev-dependencies] +p3-baby-bear.workspace = true +p3-challenger.workspace = true +p3-commit.workspace = true +p3-dft.workspace = true +p3-fri.workspace = true +p3-keccak.workspace = true +p3-koala-bear.workspace = true +p3-merkle-tree.workspace = true +p3-uni-stark.workspace = true + +tracing-forest = { workspace = true, features = ["ansi", "smallvec"] } +tracing-subscriber = { workspace = true, features = ["std", "env-filter"] } + +[features] +parallel = ["p3-maybe-rayon/parallel"] diff --git a/poseidon2-circuit-air/src/air.rs b/poseidon2-circuit-air/src/air.rs new file mode 100644 index 0000000..4b5d653 --- /dev/null +++ b/poseidon2-circuit-air/src/air.rs @@ -0,0 +1,676 @@ +use alloc::vec; +use alloc::vec::Vec; +use core::array; +use core::borrow::Borrow; + +use p3_air::{Air, AirBuilder, BaseAir}; +use p3_circuit::tables::{Poseidon2CircuitRow, Poseidon2CircuitTrace}; +use p3_field::{PrimeCharacteristicRing, PrimeField}; +use p3_matrix::Matrix; +use p3_matrix::dense::{RowMajorMatrix, RowMajorMatrixViewMut}; +use p3_poseidon2::GenericPoseidon2LinearLayers; +use p3_poseidon2_air::{Poseidon2Air, RoundConstants, generate_trace_rows}; +use p3_symmetric::CryptographicPermutation; + +use crate::sub_builder::SubAirBuilder; +use crate::{Poseidon2CircuitCols, num_cols}; + +/// Extends the Poseidon2 AIR with recursion circuit-specific columns and constraints. +/// Assumes the field size is at least 16 bits. +/// +/// SPECIFIC ASSUMPTIONS: +/// - Memory elements from the witness table are extension elements of degree D. +/// - RATE and CAPACITY are the number of extension elements in the rate/capacity. +/// - WIDTH is the number of field elements in the state, i.e., (RATE + CAPACITY) * D. +/// - `reset` can only be set during an absorb. +#[derive(Debug)] +pub struct Poseidon2CircuitAir< + F: PrimeCharacteristicRing, + LinearLayers, + const D: usize, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const CAPACITY_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + p3_poseidon2: Poseidon2Air< + F, + LinearLayers, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, +} + +impl< + F: PrimeField, + LinearLayers: GenericPoseidon2LinearLayers, + const D: usize, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const CAPACITY_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> + Poseidon2CircuitAir< + F, + LinearLayers, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > +{ + pub const fn new( + constants: RoundConstants, + ) -> Self { + assert!(CAPACITY_EXT + RATE_EXT == WIDTH_EXT); + assert!(WIDTH_EXT * D == WIDTH); + + Self { + p3_poseidon2: Poseidon2Air::new(constants), + } + } + + pub fn generate_trace_rows>( + &self, + sponge_ops: Poseidon2CircuitTrace, + constants: &RoundConstants, + extra_capacity_bits: usize, + perm: P, + ) -> RowMajorMatrix { + let n = sponge_ops.len(); + assert!( + n.is_power_of_two(), + "Callers expected to pad inputs to a power of two" + ); + + let num_circuit_cols = 3 + 2 * RATE_EXT + WIDTH_EXT; + let mut circuit_trace = vec![F::ZERO; n * num_circuit_cols]; + let mut circuit_trace = RowMajorMatrixViewMut::new(&mut circuit_trace, num_circuit_cols); + + let mut state = [F::ZERO; WIDTH]; + let mut inputs = Vec::with_capacity(n); + for (i, op) in sponge_ops.iter().enumerate() { + let Poseidon2CircuitRow { + is_sponge, + reset, + absorb_flags, + input_values, + input_indices, + output_indices, + } = op; + + let row = circuit_trace.row_mut(i); + + row[0] = if *is_sponge { F::ONE } else { F::ZERO }; + row[1] = if *reset { F::ONE } else { F::ZERO }; + row[2] = if *is_sponge && *reset { + F::ONE + } else { + F::ZERO + }; + for j in 0..RATE_EXT { + row[3 + j] = if absorb_flags[j] { F::ONE } else { F::ZERO }; + } + for j in 0..RATE_EXT { + row[3 + RATE_EXT + j] = F::from_u32(input_indices[j]); + } + for j in 0..RATE_EXT { + row[3 + RATE_EXT + WIDTH_EXT + j] = F::from_u32(output_indices[j]); + } + + let mut index_absorb = [false; RATE_EXT]; + for (j, flag) in absorb_flags.iter().enumerate() { + if *flag { + for absorb in index_absorb.iter_mut().take(j + 1) { + *absorb = true; + } + } + } + + for (j, absorb) in index_absorb.iter_mut().enumerate() { + if *absorb { + for d in 0..D { + let idx = j * D + d; + state[idx] = input_values[idx]; + } + } else if *reset { + // During a reset, non-absorbed rate elements are zeroed. + for d in 0..D { + let idx = j * D + d; + state[idx] = F::ZERO; + } + } + } + + if *reset || !*is_sponge { + // Compression or reset: reset capacity + for j in 0..(CAPACITY_EXT * D) { + state[RATE_EXT * D + j] = F::ZERO; + } + } + + inputs.push(state); + state = perm.permute(state); + } + + let p2_trace = generate_trace_rows::< + F, + LinearLayers, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(inputs, constants, extra_capacity_bits); + + let ncols = self.width(); + + let p2_ncols = p3_poseidon2_air::num_cols::< + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(); + + let mut vec = vec![F::ZERO; n * ncols]; + + for i in 0..n { + let row = &mut vec[(i * ncols)..((i + 1) * ncols)]; + let left_part = p2_trace + .row(i) + .expect("Missing row {i}?") + .into_iter() + .collect::>(); + let right_part = circuit_trace + .row(i) + .expect("Missing row {i}?") + .into_iter() + .collect::>(); + row[..p2_ncols].copy_from_slice(&left_part); + row[p2_ncols..].copy_from_slice(&right_part); + } + + unsafe { + vec.set_len(n * ncols); + } + + RowMajorMatrix::new(vec, ncols) + } +} + +impl< + F: PrimeCharacteristicRing + Sync, + LinearLayers: Sync, + const D: usize, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const CAPACITY_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> BaseAir + for Poseidon2CircuitAir< + F, + LinearLayers, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > +{ + fn width(&self) -> usize { + num_cols::< + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >() + } +} + +pub(crate) fn eval< + AB: AirBuilder, + LinearLayers: GenericPoseidon2LinearLayers, + const D: usize, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const CAPACITY_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>( + air: &Poseidon2CircuitAir< + AB::F, + LinearLayers, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + builder: &mut AB, + local: &Poseidon2CircuitCols< + AB::Var, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + next: &Poseidon2CircuitCols< + AB::Var, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, +) { + // SPONGE CONSTRAINTS + let next_no_reset = AB::Expr::ONE - next.reset.clone(); + for i in 0..(CAPACITY_EXT * D) { + // The first row has capacity zeroed. + builder + .when(local.is_sponge.clone()) + .when_first_row() + .assert_zero(local.poseidon2.inputs[RATE_EXT * D + i].clone()); + + // When resetting the state, we just have to clear the capacity. The rate will be overwritten by the input. + builder + .when(local.is_sponge.clone()) + .when(local.reset.clone()) + .assert_zero(local.poseidon2.inputs[RATE_EXT * D + i].clone()); + + // If the next row doesn't reset, propagate the capacity. + builder + .when_transition() + .when(next.is_sponge.clone()) + .when(next_no_reset.clone()) + .assert_zero( + next.poseidon2.inputs[RATE_EXT * D + i].clone() + - local.poseidon2.ending_full_rounds[HALF_FULL_ROUNDS - 1].post + [RATE_EXT * D + i] + .clone(), + ); + } + + let mut next_absorb = [AB::Expr::ZERO; RATE_EXT]; + for i in 0..RATE_EXT { + for col in next_absorb.iter_mut().take(i + 1) { + *col += next.absorb_flags[i].clone(); + } + } + let next_no_absorb = + array::from_fn::<_, RATE_EXT, _>(|i| AB::Expr::ONE - next_absorb[i].clone()); + // In the next row, each rate element not being absorbed is either: + // - zeroed if the next row is a reset (handled elsewhere); + // - copied from the current row if the next row is not a reset. + // We omit the `is_sponge` check because in a compression all absorb flags are set. + for index in 0..(RATE_EXT * D) { + let i = index / D; + let j = index % D; + builder + .when_transition() + .when(next_no_absorb[i].clone()) + .when(next_no_reset.clone()) + .assert_zero( + next.poseidon2.inputs[i * D + j].clone() + - local.poseidon2.ending_full_rounds[HALF_FULL_ROUNDS - 1].post[i * D + j] + .clone(), + ); + } + + let mut current_absorb = [AB::Expr::ZERO; RATE_EXT]; + for i in 0..RATE_EXT { + for col in current_absorb.iter_mut().take(i + 1) { + *col += local.absorb_flags[i].clone(); + } + } + let current_no_absorb = + array::from_fn::<_, RATE_EXT, _>(|i| AB::Expr::ONE - current_absorb[i].clone()); + builder.assert_eq( + local.is_sponge.clone() * local.reset.clone(), + local.sponge_reset.clone(), + ); + // During a reset, the rate elements not being absorbed are zeroed. + for (i, col) in current_no_absorb.iter().enumerate() { + let arr = array::from_fn::<_, D, _>(|j| local.poseidon2.inputs[i * D + j].clone().into()); + builder + .when(local.sponge_reset.clone() * col.clone()) + .assert_zeros(arr); + } + + let _is_squeeze = AB::Expr::ONE - current_absorb[0].clone(); + // TODO: Add all lookups: + // - If current_absorb[i] = 1: + // * local.rate[i] comes from input lookups. + // - If is_squeeze = 1: + // * local.rate is sent to output lookups. + + // COMPRESSION CONSTRAINTS + // TODO: Add all lookups: + // - local input state comes from input lookups. + // - send local output state to output lookups. + + let p3_poseidon2_num_cols = p3_poseidon2_air::num_cols::< + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(); + let mut sub_builder = SubAirBuilder::< + AB, + Poseidon2Air< + AB::F, + LinearLayers, + WIDTH, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + AB::Var, + >::new(builder, 0..p3_poseidon2_num_cols); + + // Eval the Plonky3 Poseidon2 air. + air.p3_poseidon2.eval(&mut sub_builder); +} + +impl< + AB: AirBuilder, + LinearLayers: GenericPoseidon2LinearLayers, + const D: usize, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const CAPACITY_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> Air + for Poseidon2CircuitAir< + AB::F, + LinearLayers, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > +{ + #[inline] + fn eval(&self, builder: &mut AB) { + let main = builder.main(); + let local = main.row_slice(0).expect("The matrix is empty?"); + let local = (*local).borrow(); + let next = main.row_slice(1).expect("The matrix has only one row?"); + let next = (*next).borrow(); + + eval::< + _, + _, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >(self, builder, local, next); + } +} + +#[cfg(test)] +mod test { + + use alloc::vec; + + use p3_baby_bear::{ + BabyBear, GenericPoseidon2LinearLayersBabyBear, Poseidon2ExternalLayerBabyBear, + Poseidon2InternalLayerBabyBear, + }; + use p3_challenger::{HashChallenger, SerializingChallenger32}; + use p3_circuit::tables::Poseidon2CircuitRow; + use p3_commit::ExtensionMmcs; + use p3_field::extension::BinomialExtensionField; + use p3_fri::{TwoAdicFriPcs, create_benchmark_fri_params}; + use p3_keccak::{Keccak256Hash, KeccakF}; + use p3_merkle_tree::MerkleTreeHidingMmcs; + use p3_poseidon2::{ExternalLayerConstants, Poseidon2}; + use p3_poseidon2_air::RoundConstants; + use p3_symmetric::{CompressionFunctionFromHasher, PaddingFreeSponge, SerializingHasher}; + use p3_uni_stark::{StarkConfig, prove, verify}; + use rand::rngs::SmallRng; + use rand::{Rng, SeedableRng}; + use tracing_forest::ForestLayer; + use tracing_forest::util::LevelFilter; + use tracing_subscriber::layer::SubscriberExt; + use tracing_subscriber::util::SubscriberInitExt; + use tracing_subscriber::{EnvFilter, Registry}; + + use crate::air::Poseidon2CircuitAir; + + const D: usize = 4; + const WIDTH: usize = 16; + const WIDTH_EXT: usize = 4; + const RATE_EXT: usize = 2; + const CAPACITY_EXT: usize = 2; + const SBOX_DEGREE: u64 = 7; + const SBOX_REGISTERS: usize = 1; + const HALF_FULL_ROUNDS: usize = 4; + const PARTIAL_ROUNDS: usize = 20; + + fn init_logger() { + let env_filter = EnvFilter::builder() + .with_default_directive(LevelFilter::INFO.into()) + .from_env_lossy(); + + Registry::default() + .with(env_filter) + .with(ForestLayer::default()) + .init(); + } + + #[test] + fn prove_poseidon2_sponge() -> Result< + (), + p3_uni_stark::VerificationError< + p3_fri::verifier::FriError< + p3_merkle_tree::MerkleTreeError, + p3_merkle_tree::MerkleTreeError, + >, + >, + > { + init_logger(); + type Val = BabyBear; + type Challenge = BinomialExtensionField; + + type ByteHash = Keccak256Hash; + let byte_hash = ByteHash {}; + + type U64Hash = PaddingFreeSponge; + let u64_hash = U64Hash::new(KeccakF {}); + + type FieldHash = SerializingHasher; + let field_hash = FieldHash::new(u64_hash); + + type MyCompress = CompressionFunctionFromHasher; + let compress = MyCompress::new(u64_hash); + + // WARNING: DO NOT USE SmallRng in proper applications! Use a real PRNG instead! + type ValMmcs = MerkleTreeHidingMmcs< + [Val; p3_keccak::VECTOR_LEN], + [u64; p3_keccak::VECTOR_LEN], + FieldHash, + MyCompress, + SmallRng, + 4, + 4, + >; + let mut rng = SmallRng::seed_from_u64(1); + let val_mmcs = ValMmcs::new(field_hash, compress, rng.clone()); + + type ChallengeMmcs = ExtensionMmcs; + let challenge_mmcs = ChallengeMmcs::new(val_mmcs.clone()); + + type Challenger = SerializingChallenger32>; + let challenger = Challenger::from_hasher(vec![], byte_hash); + + let fri_params = create_benchmark_fri_params(challenge_mmcs); + + let beginning_full_constants = rng.random(); + let partial_constants = rng.random(); + let ending_full_constants = rng.random(); + + let constants = RoundConstants::new( + beginning_full_constants, + partial_constants, + ending_full_constants, + ); + + let perm = Poseidon2::< + Val, + Poseidon2ExternalLayerBabyBear, + Poseidon2InternalLayerBabyBear, + WIDTH, + SBOX_DEGREE, + >::new( + ExternalLayerConstants::new( + beginning_full_constants.to_vec(), + ending_full_constants.to_vec(), + ), + partial_constants.to_vec(), + ); + + let air: Poseidon2CircuitAir< + Val, + GenericPoseidon2LinearLayersBabyBear, + D, + WIDTH, + WIDTH_EXT, + RATE_EXT, + CAPACITY_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > = Poseidon2CircuitAir::new(constants.clone()); + + // Generate random inputs. + let mut rng = SmallRng::seed_from_u64(1); + + // Absorb + let sponge_a: Poseidon2CircuitRow = Poseidon2CircuitRow { + is_sponge: true, + reset: true, + absorb_flags: vec![false, true], + input_values: (0..RATE_EXT * D).map(|_| rng.random()).collect(), + input_indices: vec![0; RATE_EXT], + output_indices: vec![0; RATE_EXT], + }; + + // Absorb + let sponge_b: Poseidon2CircuitRow = Poseidon2CircuitRow { + is_sponge: true, + reset: false, + absorb_flags: vec![false, true], + input_values: (0..RATE_EXT * D).map(|_| rng.random()).collect(), + input_indices: vec![0; RATE_EXT], + output_indices: vec![0; RATE_EXT], + }; + + // Squeeze + let sponge_c: Poseidon2CircuitRow = Poseidon2CircuitRow { + is_sponge: true, + reset: false, + absorb_flags: vec![false, false], + input_values: vec![Val::new(0); RATE_EXT * D], + input_indices: vec![0; RATE_EXT], + output_indices: vec![0; RATE_EXT], + }; + + // Absorb one element with reset + let sponge_d: Poseidon2CircuitRow = Poseidon2CircuitRow { + is_sponge: true, + reset: true, + absorb_flags: vec![true, false], + input_values: vec![ + Val::new(42), + Val::new(43), + Val::new(44), + Val::new(45), + Val::new(0), + Val::new(0), + Val::new(0), + Val::new(0), + ], + input_indices: vec![0; RATE_EXT], + output_indices: vec![0; RATE_EXT], + }; + + let trace = air.generate_trace_rows( + vec![sponge_a, sponge_b, sponge_c, sponge_d], + &constants, + fri_params.log_blowup, + perm, + ); + + type Dft = p3_dft::Radix2Bowers; + let dft = Dft::default(); + + type Pcs = TwoAdicFriPcs; + let pcs = Pcs::new(dft, val_mmcs, fri_params); + + type MyConfig = StarkConfig; + let config = MyConfig::new(pcs, challenger); + + let proof = prove(&config, &air, trace, &vec![]); + + verify(&config, &air, &proof, &vec![]) + } +} diff --git a/poseidon2-circuit-air/src/columns.rs b/poseidon2-circuit-air/src/columns.rs new file mode 100644 index 0000000..db3187c --- /dev/null +++ b/poseidon2-circuit-air/src/columns.rs @@ -0,0 +1,171 @@ +use core::borrow::{Borrow, BorrowMut}; + +use p3_poseidon2_air::Poseidon2Cols; + +/// Columns for a Poseidon2 AIR which computes one permutation per row. +/// +/// They extend the P3 columns with some circuit-specific columns. +/// +/// `is_sponge` (transparent): if `1`, this row performs a sponge operation (absorb or squeeze); +/// otherwise, it performs a compression. +/// `reset` (transparent): indicates whether the state is being reset this row. +/// `sponge_reset`: auxiliary column to keep constraint degrees below three. +/// `absorb_flags` (transparent): for each rate element, indicates if it is being absorbed this row. +/// At most one flag is set to 1 per row: if `absorb_flags[i]` is 1, then all elements up to the `i`-th +/// are absorbed; the rest are propagated from the previous row. +/// `input_indices` (transparent): for each input element, indicates the index in the witness table for the +/// memory lookup. It's either received (for an absorb or a compression) or sent (for a squeeze). +/// `output_indices` (transparent): for each output element, indicates the index in the witness table for the +/// memory lookup. Only used by compressions to send the output. +#[repr(C)] +pub struct Poseidon2CircuitCols< + T, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> { + pub poseidon2: + Poseidon2Cols, + + pub is_sponge: T, + pub reset: T, + pub sponge_reset: T, + pub absorb_flags: [T; RATE_EXT], + pub input_indices: [T; WIDTH_EXT], + pub output_indices: [T; RATE_EXT], +} + +pub const fn num_cols< + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +>() -> usize { + size_of::< + Poseidon2CircuitCols< + u8, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + >() +} + +impl< + T, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> + Borrow< + Poseidon2CircuitCols< + T, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + > for [T] +{ + fn borrow( + &self, + ) -> &Poseidon2CircuitCols< + T, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > { + let (prefix, shorts, suffix) = unsafe { + self.align_to::>() + }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &shorts[0] + } +} + +impl< + T, + const WIDTH: usize, + const WIDTH_EXT: usize, + const RATE_EXT: usize, + const SBOX_DEGREE: u64, + const SBOX_REGISTERS: usize, + const HALF_FULL_ROUNDS: usize, + const PARTIAL_ROUNDS: usize, +> + BorrowMut< + Poseidon2CircuitCols< + T, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + >, + > for [T] +{ + fn borrow_mut( + &mut self, + ) -> &mut Poseidon2CircuitCols< + T, + WIDTH, + WIDTH_EXT, + RATE_EXT, + SBOX_DEGREE, + SBOX_REGISTERS, + HALF_FULL_ROUNDS, + PARTIAL_ROUNDS, + > { + let (prefix, shorts, suffix) = unsafe { + self.align_to_mut::>() + }; + debug_assert!(prefix.is_empty(), "Alignment should match"); + debug_assert!(suffix.is_empty(), "Alignment should match"); + debug_assert_eq!(shorts.len(), 1); + &mut shorts[0] + } +} diff --git a/poseidon2-circuit-air/src/lib.rs b/poseidon2-circuit-air/src/lib.rs new file mode 100644 index 0000000..4852024 --- /dev/null +++ b/poseidon2-circuit-air/src/lib.rs @@ -0,0 +1,12 @@ +//! An AIR for the Poseidon2 table for recursion. Handles sponge operations and compressions. + +#![no_std] + +extern crate alloc; + +mod air; +mod columns; +mod sub_builder; + +pub use air::*; +pub use columns::*; diff --git a/poseidon2-circuit-air/src/sub_builder.rs b/poseidon2-circuit-air/src/sub_builder.rs new file mode 100644 index 0000000..0b91c09 --- /dev/null +++ b/poseidon2-circuit-air/src/sub_builder.rs @@ -0,0 +1,108 @@ +// Code from SP1 with minor modifications: +// https://github.com/succinctlabs/sp1/blob/main/crates/stark/src/air/sub_builder.rs + +use alloc::vec::Vec; +use core::ops::{Deref, Range}; + +use p3_air::{Air, AirBuilder}; +use p3_matrix::Matrix; + +/// A submatrix of a matrix. The matrix will contain a subset of the columns of `self.inner`. +pub struct SubMatrixRowSlices, T: Send + Sync + Clone> { + inner: M, + column_range: Range, + _phantom: core::marker::PhantomData, +} + +impl, T: Send + Sync + Clone> SubMatrixRowSlices { + /// Creates a new [`SubMatrixRowSlices`]. + #[must_use] + pub const fn new(inner: M, column_range: Range) -> Self { + Self { + inner, + column_range, + _phantom: core::marker::PhantomData, + } + } +} + +/// Implement `Matrix` for `SubMatrixRowSlices`. +impl, T: Send + Sync + Clone> Matrix for SubMatrixRowSlices { + #[inline] + fn row( + &self, + r: usize, + ) -> Option + Send + Sync>> { + self.inner.row(r).map(|row| { + row.into_iter() + .take(self.column_range.end) + .skip(self.column_range.start) + }) + } + + #[inline] + fn row_slice(&self, r: usize) -> Option> { + self.row(r)?.into_iter().collect::>().into() + } + + #[inline] + fn width(&self) -> usize { + self.column_range.len() + } + + #[inline] + fn height(&self) -> usize { + self.inner.height() + } +} + +/// A builder used to eval a sub-air. This will handle enforcing constraints for a subset of a +/// trace matrix. E.g. if a particular air needs to be enforced for a subset of the columns of +/// the trace, then the [`SubAirBuilder`] can be used. +pub struct SubAirBuilder<'a, AB: AirBuilder, SubAir: Air, T> { + inner: &'a mut AB, + column_range: Range, + _phantom: core::marker::PhantomData<(SubAir, T)>, +} + +impl<'a, AB: AirBuilder, SubAir: Air, T> SubAirBuilder<'a, AB, SubAir, T> { + /// Creates a new [`SubAirBuilder`]. + #[must_use] + pub fn new(inner: &'a mut AB, column_range: Range) -> Self { + Self { + inner, + column_range, + _phantom: core::marker::PhantomData, + } + } +} + +/// Implement `AirBuilder` for `SubAirBuilder`. +impl, F> AirBuilder for SubAirBuilder<'_, AB, SubAir, F> { + type F = AB::F; + type Expr = AB::Expr; + type Var = AB::Var; + type M = SubMatrixRowSlices; + + fn main(&self) -> Self::M { + let matrix = self.inner.main(); + + SubMatrixRowSlices::new(matrix, self.column_range.clone()) + } + + fn is_first_row(&self) -> Self::Expr { + self.inner.is_first_row() + } + + fn is_last_row(&self) -> Self::Expr { + self.inner.is_last_row() + } + + fn is_transition_window(&self, size: usize) -> Self::Expr { + self.inner.is_transition_window(size) + } + + fn assert_zero>(&mut self, x: I) { + self.inner.assert_zero(x.into()); + } +}