diff --git a/fhevm-engine/Cargo.toml b/fhevm-engine/Cargo.toml index be4d80ae..0e7bc11b 100644 --- a/fhevm-engine/Cargo.toml +++ b/fhevm-engine/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" -members = ["coprocessor", "executor", "fhevm-engine-common", "listener"] +members = ["coprocessor", "executor", "fhevm-engine-common", "listener", "sns-executor"] + [workspace.package] authors = ["Zama"] diff --git a/fhevm-engine/coprocessor/migrations/20250131125649_add_large_ct_column.sql b/fhevm-engine/coprocessor/migrations/20250131125649_add_large_ct_column.sql new file mode 100644 index 00000000..3bd78e42 --- /dev/null +++ b/fhevm-engine/coprocessor/migrations/20250131125649_add_large_ct_column.sql @@ -0,0 +1,4 @@ +ALTER TABLE ciphertexts +ADD COLUMN IF NOT EXISTS large_ct BYTEA, +ADD COLUMN IF NOT EXISTS is_sent BOOLEAN NOT NULL DEFAULT FALSE, +ADD COLUMN IF NOT EXISTS is_allowed BOOLEAN NOT NULL DEFAULT FALSE; \ No newline at end of file diff --git a/fhevm-engine/fhevm-db/migrations/20250131125649_add_large_ct_column.sql b/fhevm-engine/fhevm-db/migrations/20250131125649_add_large_ct_column.sql new file mode 100644 index 00000000..3bd78e42 --- /dev/null +++ b/fhevm-engine/fhevm-db/migrations/20250131125649_add_large_ct_column.sql @@ -0,0 +1,4 @@ +ALTER TABLE ciphertexts +ADD COLUMN IF NOT EXISTS large_ct BYTEA, +ADD COLUMN IF NOT EXISTS is_sent BOOLEAN NOT NULL DEFAULT FALSE, +ADD COLUMN IF NOT EXISTS is_allowed BOOLEAN NOT NULL DEFAULT FALSE; \ No newline at end of file diff --git a/fhevm-engine/fhevm-engine-common/src/types.rs b/fhevm-engine/fhevm-engine-common/src/types.rs index 093e2d6d..f4847171 100644 --- a/fhevm-engine/fhevm-engine-common/src/types.rs +++ b/fhevm-engine/fhevm-engine-common/src/types.rs @@ -1,8 +1,10 @@ use anyhow::Result; use bigdecimal::num_bigint::BigInt; use tfhe::integer::bigint::StaticUnsignedBigInt; +use tfhe::integer::ciphertext::BaseRadixCiphertext; use tfhe::integer::U256; use tfhe::prelude::{CiphertextList, FheDecrypt}; +use tfhe::shortint::Ciphertext; use tfhe::{CompressedCiphertextList, CompressedCiphertextListBuilder}; use crate::utils::{safe_deserialize, safe_serialize}; @@ -389,6 +391,28 @@ impl SupportedFheCiphertexts { } } + pub fn to_ciphertext64(self) -> BaseRadixCiphertext { + match self { + SupportedFheCiphertexts::FheBool(v) => { + BaseRadixCiphertext::from(vec![v.into_raw_parts()]) + } + SupportedFheCiphertexts::FheUint4(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint8(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint16(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint32(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint64(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint128(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint160(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheUint256(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheBytes64(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheBytes128(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::FheBytes256(v) => v.into_raw_parts().0, + SupportedFheCiphertexts::Scalar(_) => { + panic!("scalar cannot be converted to regular ciphertext") + } + } + } + pub fn type_num(&self) -> i16 { match self { // values taken to match with solidity library diff --git a/fhevm-engine/sns-executor/Cargo.toml b/fhevm-engine/sns-executor/Cargo.toml new file mode 100644 index 00000000..e54d5a01 --- /dev/null +++ b/fhevm-engine/sns-executor/Cargo.toml @@ -0,0 +1,45 @@ +[package] +name = "sns-executor" +version = "0.1.0" +authors.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +# workspace dependencies +bincode = { workspace = true } +clap = { workspace = true } +prometheus = { workspace = true } +prost = { workspace = true } +rayon = { workspace = true } +sha3 = { workspace = true } +tokio = { workspace = true } +tonic = { workspace = true } +tracing = { workspace = true } +tracing-subscriber = { workspace = true } +anyhow = { workspace = true } +serde = { workspace = true } +hex = "0.4" + +aligned-vec = "0.5.0" +num-traits = "=0.2.19" +sqlx = { version = "0.7", features = ["runtime-tokio", "tls-rustls", "postgres", "uuid"] } + +serde_json = "=1.0" + +# local dependencies +fhevm-engine-common = { path = "../fhevm-engine-common" } + +# arch-specific dependencies +[target.'cfg(target_arch = "x86_64")'.dependencies] +tfhe = { workspace = true, features = ["x86_64-unix"] } +[target.'cfg(target_arch = "aarch64")'.dependencies] +tfhe = { workspace = true, features = ["aarch64-unix"] } + +[[bin]] +name = "sns_worker" +path = "src/bin/sns_worker.rs" + + +[features] +decrypt_128 = [] \ No newline at end of file diff --git a/fhevm-engine/sns-executor/README.md b/fhevm-engine/sns-executor/README.md new file mode 100644 index 00000000..b63aa4a0 --- /dev/null +++ b/fhevm-engine/sns-executor/README.md @@ -0,0 +1,29 @@ +# SnS executor + +## Description + +### Library crate (sns-executor) + +Executes a loop that: +- Retrieves `(handle, compressed_ct)` pairs from PG table.ciphertexts marked as `allowed`. +- Computes `large_ct` using the SnS algorithm. +- Updates the `large_ct` column corresponding to the specified handle. +- Sends a signal indicating the availability of newly computed `large_ct`. + +#### Features +**decrypt_128** - Decrypt each `large_ct` and print it as a plaintext (for testing purposes only). + +### Binary (sns-worker) + +Runs sns-executor. See also `src/bin/utils/daemon_cli.rs` + + +## How to run a sns-worker + +``` +DATABASE_URL=postgresql://postgres:postgres@localhost:5432/coprocessor \ +cargo run --release -- \ +--pg-listen-channel "allowed_handles" \ +--pg-notify-channel "computed_handles" \ +--keys-file-path "./default_keys.bin" +``` \ No newline at end of file diff --git a/fhevm-engine/sns-executor/src/bin/sns_worker.rs b/fhevm-engine/sns-executor/src/bin/sns_worker.rs new file mode 100644 index 00000000..30bc9879 --- /dev/null +++ b/fhevm-engine/sns-executor/src/bin/sns_worker.rs @@ -0,0 +1,56 @@ +use serde::{de::DeserializeOwned, Serialize}; +use sns_executor::DBConfig; +use std::fs; +use tokio::{signal::unix, sync::broadcast}; + +mod utils; + +fn read_element(file_path: String) -> anyhow::Result { + let read_element = fs::read(file_path.clone())?; + Ok(bincode::deserialize_from(read_element.as_slice())?) +} + +fn handle_sigint(cancel_tx: broadcast::Sender<()>) { + tokio::spawn(async move { + let mut signal = unix::signal(unix::SignalKind::interrupt()).unwrap(); + signal.recv().await; + cancel_tx.send(()).unwrap(); + }); +} + +#[tokio::main] +async fn main() { + let args = utils::daemon_cli::parse_args(); + + // Read keys from the file path, if specified + let mut keys = None; + if let Some(path) = args.keys_file_path { + keys = Some(read_element(path).expect("Failed to read keys.")); + } + + let db_url = args + .database_url + .clone() + .unwrap_or_else(|| std::env::var("DATABASE_URL").expect("DATABASE_URL is undefined")); + + tracing_subscriber::fmt().json().with_level(true).init(); + + let conf = sns_executor::Config { + db: DBConfig { + url: db_url, + listen_channel: args.pg_listen_channel, + notify_channel: args.pg_notify_channel, + batch_limit: args.work_items_batch_size, + polling_interval: args.pg_polling_interval, + max_connections: args.pg_pool_connections, + }, + }; + + // Handle SIGINIT signals + let (cancel_tx, cancel_rx) = broadcast::channel(1); + handle_sigint(cancel_tx); + + if let Err(err) = sns_executor::run(keys, &conf, cancel_rx).await { + tracing::error!("Worker failed: {:?}", err); + } +} diff --git a/fhevm-engine/sns-executor/src/bin/utils/daemon_cli.rs b/fhevm-engine/sns-executor/src/bin/utils/daemon_cli.rs new file mode 100644 index 00000000..f1ff0fd6 --- /dev/null +++ b/fhevm-engine/sns-executor/src/bin/utils/daemon_cli.rs @@ -0,0 +1,41 @@ +use clap::{command, Parser}; + +#[derive(Parser, Debug, Clone)] +#[command(version, about, long_about = None)] +pub struct Args { + /// Work items batch size + #[arg(long, default_value_t = 4)] + pub work_items_batch_size: u32, + + /// NOTIFY/LISTEN channel for database that the worker listen to + #[arg(long)] + pub pg_listen_channel: String, + + /// NOTIFY/LISTEN channel for database that the worker notify to + #[arg(long)] + pub pg_notify_channel: String, + + /// Polling interval in seconds + #[arg(long, default_value_t = 60)] + pub pg_polling_interval: u32, + + /// Postgres pool connections + #[arg(long, default_value_t = 10)] + pub pg_pool_connections: u32, + + /// Postgres database url. If unspecified DATABASE_URL environment variable is used + #[arg(long)] + pub database_url: Option, + + /// KeySet file. If unspecified the the keys are read from the database (not implemented) + #[arg(long)] + pub keys_file_path: Option, + + /// sns-executor service name in OTLP traces (not implemented) + #[arg(long, default_value = "sns-executor")] + pub service_name: String, +} + +pub fn parse_args() -> Args { + Args::parse() +} diff --git a/fhevm-engine/sns-executor/src/bin/utils/mod.rs b/fhevm-engine/sns-executor/src/bin/utils/mod.rs new file mode 100644 index 00000000..6dfe2b11 --- /dev/null +++ b/fhevm-engine/sns-executor/src/bin/utils/mod.rs @@ -0,0 +1 @@ +pub mod daemon_cli; diff --git a/fhevm-engine/sns-executor/src/executor.rs b/fhevm-engine/sns-executor/src/executor.rs new file mode 100644 index 00000000..891ab47a --- /dev/null +++ b/fhevm-engine/sns-executor/src/executor.rs @@ -0,0 +1,261 @@ +use std::error::Error; +use std::time::Duration; + +use sqlx::postgres::PgListener; +use sqlx::{Acquire, PgPool, Postgres, Transaction}; +use tfhe::integer::IntegerCiphertext; +use tfhe::set_server_key; +use tokio::select; +use tokio::sync::broadcast; +use tracing::{debug, error, info}; + +use crate::{switch_and_squash::Ciphertext128, KeySet}; +use crate::{Config, DBConfig}; + +use fhevm_engine_common::types::{get_ct_type, SupportedFheCiphertexts}; + +const RETRY_DB_CONN_INTERVAL: Duration = Duration::from_secs(5); + +enum ConnStatus { + Established(sqlx::pool::PoolConnection), + Failed, + Cancelled, +} + +struct SnSTask { + handle: Vec, + compressed: Vec, + large_ct: Option, +} + +/// Executes the worker logic for the SnS task. +pub(crate) async fn run_loop( + keys: Option, + conf: &Config, + mut cancel_chan: broadcast::Receiver<()>, +) -> Result<(), Box> { + let keys = keys.unwrap_or_else(|| unimplemented!("Read keys from the database")); + let conf = &conf.db; + + let pool = sqlx::postgres::PgPoolOptions::new() + .max_connections(conf.max_connections) + .connect(&conf.url) + .await?; + + let mut listener = PgListener::connect_with(&pool).await?; + listener.listen(&conf.listen_channel).await?; + + loop { + let mut conn = match acquire_connection(&pool, &mut cancel_chan).await { + ConnStatus::Established(conn) => conn, + ConnStatus::Failed => { + tokio::time::sleep(RETRY_DB_CONN_INTERVAL).await; + continue; // Retry to reacquire a connection + } + ConnStatus::Cancelled => return Ok(()), + }; + + loop { + if let Err(err) = poll_and_execute_sns_tasks(&mut conn, &keys, conf).await { + error!(target: "worker", "Failed to poll and execute tasks: {err}"); + break; // Break to reacquire a connection + } + + select! { + _ = cancel_chan.recv() => return Ok(()), + _ = listener.try_recv() => { + debug!(target: "worker", "Received notification"); + }, + _ = tokio::time::sleep(Duration::from_secs(conf.polling_interval.into())) => { + debug!(target: "worker", "Polling timeout, rechecking for tasks"); + } + } + } + } +} + +/// Polls the database for tasks and executes them. +async fn poll_and_execute_sns_tasks( + conn: &mut sqlx::pool::PoolConnection, + keys: &KeySet, + conf: &DBConfig, +) -> Result<(), Box> { + let mut db_txn = match conn.begin().await { + Ok(txn) => txn, + Err(err) => { + error!(target: "worker", "Failed to begin transaction: {err}"); + return Err(err.into()); + } + }; + + if let Some(mut tasks) = query_sns_tasks(&mut db_txn, conf.batch_limit).await? { + process_tasks(&mut tasks, keys)?; + update_large_ct(&mut db_txn, &tasks).await?; + notify_large_ct_ready(&mut db_txn, &conf.notify_channel).await?; + db_txn.commit().await?; + } else { + db_txn.rollback().await?; + } + + Ok(()) +} + +async fn acquire_connection( + pool: &PgPool, + cancel_chan: &mut broadcast::Receiver<()>, +) -> ConnStatus { + select! { + conn = pool.acquire() => match conn { + Ok(conn) => ConnStatus::Established(conn), + Err(err) => { + error!(target: "worker", "Failed to acquire connection: {err}"); + ConnStatus::Failed + } + }, + _ = cancel_chan.recv() => { + info!(target: "worker", "Cancellation received while acquiring connection"); + ConnStatus::Cancelled + } + } +} + +/// Queries the database for a fixed number of tasks. +async fn query_sns_tasks( + db_txn: &mut Transaction<'_, Postgres>, + limit: u32, +) -> Result>, Box> { + let records = sqlx::query!( + " + SELECT handle, ciphertext + FROM ciphertexts + WHERE ciphertext IS NOT NULL + AND is_allowed = TRUE + AND is_sent = FALSE + AND large_ct IS NULL + FOR UPDATE SKIP LOCKED + LIMIT $1;", + limit as i64 + ) + .fetch_all(db_txn.as_mut()) + .await?; + + info!(target: "sns", { count = records.len()}, "Fetched SnS tasks"); + + if records.is_empty() { + return Ok(None); + } + + let tasks = records + .into_iter() + .map(|record| SnSTask { + handle: record.handle, + compressed: record.ciphertext, + large_ct: None, + }) + .collect(); + + Ok(Some(tasks)) +} + +/// Processes the tasks by decompressing and transforming ciphertexts. +fn process_tasks(tasks: &mut [SnSTask], keys: &KeySet) -> Result<(), Box> { + set_server_key(keys.public_keys.server_key.clone()); + + for task in tasks.iter_mut() { + let ct = decompress_ct(&task.handle, &task.compressed)?; + let raw_ct = ct.to_ciphertext64(); + let handle = to_hex(&task.handle); + + let blocks = raw_ct.blocks().len(); + info!(target: "sns", { handle, blocks }, "Converting ciphertext"); + + let sns_key = keys + .public_keys + .sns_key + .as_ref() + .ok_or_else(|| "sns_key not found".to_string())?; + + let large_ct = sns_key.to_large_ciphertext(&raw_ct).map_err(|e| { + format!( + "Failed to convert to large ciphertext: handle: {} {}", + handle, e + ) + })?; + + info!(target: "sns", { handle }, "Ciphertext converted"); + + // Optional: Decrypt and log for debugging + #[cfg(feature = "decrypt_128")] + { + let decrypted = keys.sns_secret_key.decrypt_128(&large_ct); + info!(target: "sns", { handle, decrypted }, "Decrypted plaintext"); + } + + task.large_ct = Some(large_ct); + } + + Ok(()) +} + +/// Updates the database with the computed large ciphertexts. +async fn update_large_ct( + db_txn: &mut Transaction<'_, Postgres>, + tasks: &[SnSTask], +) -> Result<(), Box> { + for task in tasks { + if let Some(large_ct) = &task.large_ct { + let large_ct_bytes = bincode::serialize(large_ct)?; + sqlx::query!( + " + UPDATE ciphertexts + SET large_ct = $1 + WHERE handle = $2;", + large_ct_bytes, + task.handle + ) + .execute(db_txn.as_mut()) + .await?; + } else { + error!(target: "worker", handle = ?task.handle, "Large ciphertext not computed for task"); + } + } + Ok(()) +} + +/// Notifies the database that large ciphertexts are ready. +async fn notify_large_ct_ready( + db_txn: &mut Transaction<'_, Postgres>, + db_channel: &str, +) -> Result<(), Box> { + sqlx::query("SELECT pg_notify($1, '')") + .bind(db_channel) + .execute(db_txn.as_mut()) + .await?; + Ok(()) +} + +/// Decompresses a ciphertext based on its type. +fn decompress_ct( + handle: &[u8], + compressed_ct: &[u8], +) -> Result> { + let ct_type = get_ct_type(handle)?; + SupportedFheCiphertexts::decompress(ct_type, compressed_ct).map_err(|e| e.into()) +} + +// Print first 4 and last 4 bytes of a blob as hex +fn to_hex(handle: &[u8]) -> String { + const OFFSET: usize = 8; + match handle.len() { + 0 => String::from("0x"), + len if len <= 2 * OFFSET => format!("0x{}", hex::encode(handle)), + _ => { + let hex_str = hex::encode(handle); + format!( + "0x{}...{}", + &hex_str[..OFFSET], + &hex_str[hex_str.len() - OFFSET..] + ) + } + } +} diff --git a/fhevm-engine/sns-executor/src/lib.rs b/fhevm-engine/sns-executor/src/lib.rs new file mode 100644 index 00000000..faf00294 --- /dev/null +++ b/fhevm-engine/sns-executor/src/lib.rs @@ -0,0 +1,63 @@ +mod executor; +mod switch_and_squash; + +use serde::{Deserialize, Serialize}; +use switch_and_squash::{SnsClientKey, SwitchAndSquashKey}; +use tokio::sync::broadcast; +use tracing::info; + +#[derive(Serialize, Deserialize, Clone)] +pub struct FhePubKeySet { + pub public_key: tfhe::CompactPublicKey, + pub server_key: tfhe::ServerKey, + pub sns_key: Option, +} +#[derive(Serialize, Deserialize, Clone)] +pub struct KeySet { + pub client_key: tfhe::ClientKey, + pub sns_secret_key: SnsClientKey, + pub public_keys: FhePubKeySet, +} + +pub struct DBConfig { + pub url: String, + pub listen_channel: String, + pub notify_channel: String, + pub batch_limit: u32, + pub polling_interval: u32, + pub max_connections: u32, +} + +pub struct Config { + pub db: DBConfig, +} + +/// Implement Display for Config +impl std::fmt::Display for Config { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!( + f, + "db_url: {}, db_listen_channel: {}, db_notify_channel: {}, db_batch_limit: {}", + self.db.url, self.db.listen_channel, self.db.notify_channel, self.db.batch_limit + ) + } +} + +/// Starts the worker loop +/// +/// # Arguments +/// +/// * `keys` - The keys to use for the worker +/// * `limit` - The maximum number of tasks to process per iteration +pub async fn run( + keys: Option, + conf: &Config, + cancel_chan: broadcast::Receiver<()>, +) -> Result<(), Box> { + info!(target: "sns", "Worker started with {}", conf); + + executor::run_loop(keys, conf, cancel_chan).await?; + + info!(target: "sns", "Worker stopped"); + Ok(()) +} diff --git a/fhevm-engine/sns-executor/src/switch_and_squash.rs b/fhevm-engine/sns-executor/src/switch_and_squash.rs new file mode 100644 index 00000000..afe81af1 --- /dev/null +++ b/fhevm-engine/sns-executor/src/switch_and_squash.rs @@ -0,0 +1,338 @@ +use aligned_vec::ABox; +use anyhow::anyhow; +use rayon::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt; +use std::num::Wrapping; + +use std::panic::Location; + +use tfhe::boolean::prelude::PolynomialSize; +use tfhe::core_crypto::prelude::{ + allocate_and_trivially_encrypt_new_glwe_ciphertext, decrypt_lwe_ciphertext, + keyswitch_lwe_ciphertext, lwe_ciphertext_cleartext_mul_assign, + programmable_bootstrap_f128_lwe_ciphertext, CiphertextModulus, Cleartext, Container, + ContainerMut, Fourier128LweBootstrapKey, GlweCiphertextOwned, GlweSize, LweCiphertext, + LweCiphertextOwned, LweKeyswitchKey, LweSecretKeyOwned, LweSize, PlaintextList, + UnsignedInteger, UnsignedTorus, +}; + +use tfhe::{ + core_crypto::commons::traits::{CastFrom, CastInto}, + integer::IntegerCiphertext, + shortint::PBSOrder, +}; + +use tfhe::integer::block_decomposition::BlockRecomposer; +use tfhe::integer::ciphertext::BaseRadixCiphertext; + +use tfhe::shortint::ClassicPBSParameters; + +pub type Z128 = Wrapping; +use num_traits::{AsPrimitive, ConstZero}; + +pub type Ciphertext64 = BaseRadixCiphertext; +pub type Ciphertext64Block = tfhe::shortint::Ciphertext; +// Observe that tfhe-rs is hard-coded to use u64, hence we require custom types for the 128 bit versions for now. +pub type Ciphertext128 = Vec; +pub type Ciphertext128Block = LweCiphertextOwned; + +// NOTE: the below is copied from core/threshold +// since the calling tracing from another crate +// does not generate correct logs in tracing_test::traced_test +#[track_caller] +pub(crate) fn anyhow_error_and_log + fmt::Display>(msg: S) -> anyhow::Error { + println!("Error in {}: {}", Location::caller(), msg); + anyhow!("Error in {}: {}", Location::caller(), msg) +} + +/// Key used for switch-and-squash to convert a ciphertext over u64 to one over u128 +#[derive(Serialize, Deserialize, Clone, PartialEq)] +pub struct SwitchAndSquashKey { + pub fbsk_out: Fourier128LweBootstrapKey>, + //ksk is needed if PBSOrder is KS-PBS + pub ksk: LweKeyswitchKey>, +} + +pub trait AugmentedCiphertextParameters { + // Return the minimum amount of bits that can be used for a message in each block. + fn message_modulus_log(&self) -> u32; + + // Return the minimum amount of bits that can be used for a carry in each block. + fn carry_modulus_log(&self) -> u32; + // Return the minimum total amounts of availble bits in each block. I.e. including both message and carry bits + fn total_block_bits(&self) -> u32; +} + +impl AugmentedCiphertextParameters for tfhe::shortint::Ciphertext { + // Return the minimum amount of bits that can be used for a message in each block. + fn message_modulus_log(&self) -> u32 { + self.message_modulus.0.ilog2() + } + + // Return the minimum amount of bits that can be used for a carry in each block. + fn carry_modulus_log(&self) -> u32 { + self.carry_modulus.0.ilog2() + } + + // Return the minimum total amounts of availble bits in each block. I.e. including both message and carry bits + fn total_block_bits(&self) -> u32 { + self.carry_modulus_log() + self.message_modulus_log() + } +} + +impl SwitchAndSquashKey { + pub fn new( + fbsk_out: Fourier128LweBootstrapKey>, + ksk: LweKeyswitchKey>, + ) -> Self { + SwitchAndSquashKey { fbsk_out, ksk } + } + + /// Converts a ciphertext over a 64 bit domain to a ciphertext over a 128 bit domain (which is needed for secure threshold decryption). + /// Conversion is done using a precreated conversion key [conversion_key]. + /// Observe that the decryption key will be different after conversion, since [conversion_key] is actually a key-switching key. + /// This version computes SnS on all blocks in parallel. + pub fn to_large_ciphertext( + &self, + raw_small_ct: &Ciphertext64, + ) -> anyhow::Result { + let blocks = raw_small_ct.blocks(); + // do switch and squash on all blocks in parallel + let res = blocks + .par_iter() + .map(|current_block| self.to_large_ciphertext_block(current_block)) + .collect::>>()?; + Ok(res) + } + + /// Converts a single ciphertext block over a 64 bit domain to a ciphertext block over a 128 bit domain (which is needed for secure threshold decryption). + /// Conversion is done using a precreated conversion key, [conversion_key]. + /// Observe that the decryption key will be different after conversion, since [conversion_key] is actually a key-switching key. + pub fn to_large_ciphertext_block( + &self, + small_ct_block: &Ciphertext64Block, + ) -> anyhow::Result { + let total_bits = small_ct_block.total_block_bits(); + + // Accumulator definition + let delta = 1_u64 << (u64::BITS - 1 - total_bits); + let msg_modulus = 1_u64 << total_bits; + + let f_out = |x: u128| x; + let delta_u128 = (delta as u128) << 64; + let accumulator_out: GlweCiphertextOwned = Self::generate_accumulator( + self.fbsk_out.polynomial_size(), + self.fbsk_out.glwe_size(), + msg_modulus.cast_into(), + CiphertextModulus::::new_native(), + delta_u128, + f_out, + ); + + //MSUP + let mut ms_output_lwe = LweCiphertext::new( + 0_u128, + self.fbsk_out.input_lwe_dimension().to_lwe_size(), + CiphertextModulus::new_native(), + ); + //If ctype = F-GLWE we need to KS before doing the Bootstrap + if small_ct_block.pbs_order == PBSOrder::KeyswitchBootstrap { + let mut output_raw_ctxt = + LweCiphertext::new(0, self.ksk.output_lwe_size(), self.ksk.ciphertext_modulus()); + keyswitch_lwe_ciphertext(&self.ksk, &small_ct_block.ct, &mut output_raw_ctxt); + Self::lwe_ciphertext_modulus_switch_up(&mut ms_output_lwe, &output_raw_ctxt)?; + } else { + Self::lwe_ciphertext_modulus_switch_up(&mut ms_output_lwe, &small_ct_block.ct)?; + }; + + let pbs_cipher_size = LweSize( + 1 + self.fbsk_out.glwe_size().to_glwe_dimension().0 * self.fbsk_out.polynomial_size().0, + ); + let mut out_pbs_ct = LweCiphertext::new( + 0_u128, + pbs_cipher_size, + CiphertextModulus::::new_native(), + ); + programmable_bootstrap_f128_lwe_ciphertext( + &ms_output_lwe, + &mut out_pbs_ct, + &accumulator_out, + &self.fbsk_out, + ); + Ok(out_pbs_ct) + } + + // Here we will define a helper function to generate an accumulator for a PBS + fn generate_accumulator>( + polynomial_size: PolynomialSize, + glwe_size: GlweSize, + message_modulus: usize, + ciphertext_modulus: CiphertextModulus, + delta: Scalar, + f: F, + ) -> GlweCiphertextOwned + where + F: Fn(Scalar) -> Scalar, + { + // N/(p/2) = size of each block, to correct noise from the input we introduce the + // notion of box, which manages redundancy to yield a denoised value + // for several noisy values around a true input value. + let box_size = polynomial_size.0 / message_modulus; + + // Create the accumulator + let mut accumulator_scalar = vec![Scalar::ZERO; polynomial_size.0]; + + // Fill each box with the encoded denoised value + for i in 0..message_modulus { + let index = i * box_size; + accumulator_scalar[index..index + box_size] + .iter_mut() + .for_each(|a| *a = f(Scalar::cast_from(i)) * delta); + } + + let half_box_size = box_size / 2; + + // Negate the first half_box_size coefficients to manage negacyclicity and rotate + for a_i in accumulator_scalar[0..half_box_size].iter_mut() { + *a_i = (*a_i).wrapping_neg(); + } + + // Rotate the accumulator + accumulator_scalar.rotate_left(half_box_size); + + let accumulator_plaintext = PlaintextList::from_container(accumulator_scalar); + + allocate_and_trivially_encrypt_new_glwe_ciphertext( + glwe_size, + &accumulator_plaintext, + ciphertext_modulus, + ) + } + + /// The method below is copied from the `noise-gap-exp` branch in tfhe-rs-internal (and added error handling) + /// since this branch will likely not be merged in main. + /// + /// Takes a ciphertext, `input`, of a certain domain, [InputScalar] and overwrites the content of `output` + /// with the ciphertext converted to the [OutputScaler] domain. + fn lwe_ciphertext_modulus_switch_up( + output: &mut LweCiphertext, + input: &LweCiphertext, + ) -> anyhow::Result<()> + where + InputScalar: UnsignedInteger + CastInto, + OutputScalar: UnsignedInteger, + InputCont: Container, + OutputCont: ContainerMut, + { + if !input.ciphertext_modulus().is_native_modulus() { + return Err(anyhow_error_and_log( + "Ciphertext modulus is not native, which is the only kind supported", + )); + } + + output + .as_mut() + .iter_mut() + .zip(input.as_ref().iter()) + .for_each(|(dst, &src)| *dst = src.cast_into()); + let modulus_up: CiphertextModulus = input + .ciphertext_modulus() + .try_to() + .map_err(|_| anyhow_error_and_log("Could not parse ciphertext modulus"))?; + + lwe_ciphertext_cleartext_mul_assign( + output, + Cleartext(modulus_up.get_power_of_two_scaling_to_native_torus()), + ); + Ok(()) + } +} + +impl AugmentedCiphertextParameters for ClassicPBSParameters { + // Return the minimum amount of bits that can be used for a message in each block. + fn message_modulus_log(&self) -> u32 { + self.message_modulus.0.ilog2() + } + + // Return the minimum amount of bits that can be used for a carry in each block. + fn carry_modulus_log(&self) -> u32 { + self.carry_modulus.0.ilog2() + } + + // Return the minimum total amounts of availble bits in each block. I.e. including both message and carry bits + fn total_block_bits(&self) -> u32 { + self.carry_modulus_log() + self.message_modulus_log() + } +} + +#[derive(Serialize, Deserialize, Clone)] +pub struct SnsClientKey { + pub key: LweSecretKeyOwned, + pub params: ClassicPBSParameters, +} + +impl SnsClientKey { + pub fn new(params: ClassicPBSParameters, sns_secret_key: LweSecretKeyOwned) -> Self { + SnsClientKey { + key: sns_secret_key, + params, + } + } + + #[cfg(feature = "decrypt_128")] + pub fn decrypt_128(&self, ct: &Ciphertext128) -> u128 { + if ct.is_empty() { + return 0; + } + + let bits_in_block = self.params.message_modulus_log(); + let mut recomposer = BlockRecomposer::::new(bits_in_block); + + for encrypted_block in ct { + let decrypted_block = self.decrypt_block_128(encrypted_block); + if !recomposer.add_unmasked(decrypted_block.0) { + // End of T::BITS reached no need to try more + // recomposition + break; + }; + } + + recomposer.value() + } + + #[cfg(feature = "decrypt_128")] + pub(crate) fn decrypt_block_128(&self, ct: &Ciphertext128Block) -> Z128 { + let total_bits = self.params.total_block_bits() as usize; + let raw_plaintext = decrypt_lwe_ciphertext(&self.key, ct); + from_expanded_msg(raw_plaintext.0, total_bits) + } +} + +// Map a raw, decrypted message to its real value by dividing by the appropriate shift, delta, assuming padding +pub(crate) fn from_expanded_msg>( + raw_plaintext: Scalar, + message_and_carry_mod_bits: usize, +) -> Z128 { + // delta = q/t where t is the amount of plain text bits + // Observe that t includes the message and carry bits as well as the padding bit (hence the + 1) + let delta_pad_bits = (Scalar::BITS as u128) - (message_and_carry_mod_bits as u128 + 1_u128); + + // Observe that in certain situations the computation of b- may be negative + // Concretely this happens when the message encrypted is 0 and randomness ends up being negative. + // We cannot simply do the standard modulo operation then, as this would mean the message becomes + // 2^message_mod_bits instead of 0 as it should be. + // However the maximal negative value it can have (without a general decryption error) is delta/2 + // which we can compute as 1 << delta_pad_bits, since the padding already halves the true delta + if raw_plaintext.as_() > Scalar::MAX.as_() - (1 << delta_pad_bits) { + Z128::ZERO + } else { + // compute delta / 2 + let delta_pad_half = 1 << (delta_pad_bits - 1); + + // add delta/2 to kill the negative noise, note this does not affect the message. + // and then divide by delta + let raw_msg = raw_plaintext.as_().wrapping_add(delta_pad_half) >> delta_pad_bits; + Wrapping(raw_msg % (1 << message_and_carry_mod_bits)) + } +}