Skip to content
This repository was archived by the owner on Jun 11, 2025. It is now read-only.

feat: add ciphertext compression in the coprocessor #36

Merged
merged 1 commit into from
Sep 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 44 additions & 32 deletions fhevm-engine/coprocessor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::db_queries::{
};
use crate::server::coprocessor::GenericResponse;
use crate::types::{CoprocessorError, TfheTenantKeys};
use crate::utils::sort_computations_by_dependencies;
use crate::utils::{set_server_key_if_not_set, sort_computations_by_dependencies};
use alloy::signers::local::PrivateKeySigner;
use alloy::signers::SignerSync;
use alloy::sol_types::SolStruct;
Expand All @@ -23,6 +23,7 @@ use fhevm_engine_common::tfhe_ops::{
use fhevm_engine_common::types::{FhevmError, SupportedFheCiphertexts, SupportedFheOperations};
use sha3::{Digest, Keccak256};
use sqlx::{query, Acquire};
use tokio::task::spawn_blocking;
use tonic::transport::Server;

pub mod common {
Expand Down Expand Up @@ -177,14 +178,15 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
))
})?;
let acl_contract_address =
alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address).map_err(|e| {
tonic::Status::from_error(Box::new(
CoprocessorError::CannotParseTenantEthereumAddress {
bad_address: fetch_key_response.acl_contract_address.clone(),
parsing_error: e.to_string(),
},
))
})?;
alloy::primitives::Address::from_str(&fetch_key_response.acl_contract_address)
.map_err(|e| {
tonic::Status::from_error(Box::new(
CoprocessorError::CannotParseTenantEthereumAddress {
bad_address: fetch_key_response.acl_contract_address.clone(),
parsing_error: e.to_string(),
},
))
})?;

let eip_712_domain = alloy::sol_types::eip712_domain! {
name: "InputVerifier",
Expand Down Expand Up @@ -224,12 +226,12 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
let server_key = server_key.clone();
tfhe_work_set.spawn_blocking(
move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, usize)> {
let expanded =
try_expand_ciphertext_list(&cloned_input.input_payload, &server_key)
.map_err(|e| {
let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e);
(err, idx)
})?;
set_server_key_if_not_set(tenant_id, &server_key);
let expanded = try_expand_ciphertext_list(&cloned_input.input_payload)
.map_err(|e| {
let err: Box<(dyn std::error::Error + Send + Sync)> = Box::new(e);
(err, idx)
})?;

Ok((expanded, idx))
},
Expand Down Expand Up @@ -280,7 +282,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
assert_eq!(blob_hash.len(), 32, "should be 32 bytes");

let corresponding_unpacked = results
.get(&idx)
.remove(&idx)
.expect("we should have all results computed now");

// save blob for audits and historical reference
Expand Down Expand Up @@ -320,21 +322,31 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
signer_address: self.signer.address().to_string(),
};

for (ct_idx, the_ct) in corresponding_unpacked.iter().enumerate() {
let (serialized_type, serialized_ct) = the_ct.serialize();
let ciphertext_version = current_ciphertext_version();
let mut handle_hash = Keccak256::new();
handle_hash.update(&blob_hash);
handle_hash.update(&[ct_idx as u8]);
handle_hash.update(acl_contract_address.as_slice());
handle_hash.update(&chain_id_be);
let mut handle = handle_hash.finalize().to_vec();
assert_eq!(handle.len(), 32);
// idx cast to u8 must succeed because we don't allow
// more handles than u8 size
handle[29] = ct_idx as u8;
handle[30] = serialized_type as u8;
handle[31] = ciphertext_version as u8;
let ciphertext_version = current_ciphertext_version();
for (ct_idx, the_ct) in corresponding_unpacked.into_iter().enumerate() {
// TODO: simplify compress and hash computation async handling
let blob_hash_clone = blob_hash.clone();
let server_key_clone = server_key.clone();
let (handle, serialized_ct, serialized_type) = spawn_blocking(move || {
set_server_key_if_not_set(tenant_id, &server_key_clone);
let (serialized_type, serialized_ct) = the_ct.compress();
let mut handle_hash = Keccak256::new();
handle_hash.update(&blob_hash_clone);
handle_hash.update(&[ct_idx as u8]);
handle_hash.update(acl_contract_address.as_slice());
handle_hash.update(&chain_id_be);
let mut handle = handle_hash.finalize().to_vec();
assert_eq!(handle.len(), 32);
// idx cast to u8 must succeed because we don't allow
// more handles than u8 size
handle[29] = ct_idx as u8;
handle[30] = serialized_type as u8;
handle[31] = ciphertext_version as u8;

(handle, serialized_ct, serialized_type)
})
.await
.map_err(|e| tonic::Status::from_error(Box::new(e)))?;

let _ = sqlx::query!(
"
Expand Down Expand Up @@ -572,7 +584,7 @@ impl coprocessor::fhevm_coprocessor_server::FhevmCoprocessor for CoprocessorServ
let mut res: Vec<(Vec<u8>, i16, Vec<u8>)> = Vec::with_capacity(cloned.len());
for v in cloned {
let ct = trivial_encrypt_be_bytes(v.output_type as i16, &v.be_value);
let (ct_type, ct_bytes) = ct.serialize();
let (ct_type, ct_bytes) = ct.compress();
res.push((v.handle, ct_type, ct_bytes));
}

Expand Down
22 changes: 13 additions & 9 deletions fhevm-engine/coprocessor/src/tests/utils.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::cli::Args;
use fhevm_engine_common::tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext};
use fhevm_engine_common::tfhe_ops::current_ciphertext_version;
use fhevm_engine_common::types::SupportedFheCiphertexts;
use rand::Rng;
use std::collections::BTreeMap;
use std::sync::atomic::{AtomicU16, Ordering};
Expand Down Expand Up @@ -240,9 +241,9 @@ pub async fn decrypt_ciphertexts(
tenant_id: i32,
input: Vec<Vec<u8>>,
) -> Result<Vec<DecryptionResult>, Box<dyn std::error::Error>> {
let mut priv_key = sqlx::query!(
let mut keys = sqlx::query!(
"
SELECT cks_key
SELECT cks_key, sks_key
FROM tenants
WHERE tenant_id = $1
",
Expand All @@ -251,16 +252,16 @@ pub async fn decrypt_ciphertexts(
.fetch_all(pool)
.await?;

if priv_key.is_empty() || priv_key[0].cks_key.is_none() {
panic!("tenant private key not found");
if keys.is_empty() || keys[0].cks_key.is_none() {
panic!("tenant keys not found");
}

let mut ct_indexes: BTreeMap<&[u8], usize> = BTreeMap::new();
for (idx, h) in input.iter().enumerate() {
ct_indexes.insert(h.as_slice(), idx);
}

assert_eq!(priv_key.len(), 1);
assert_eq!(keys.len(), 1);

let cts = sqlx::query!(
"
Expand All @@ -281,15 +282,18 @@ pub async fn decrypt_ciphertexts(
panic!("ciphertext not found");
}

let priv_key = priv_key.pop().unwrap().cks_key.unwrap();
let keys = keys.pop().unwrap();

let mut values = tokio::task::spawn_blocking(move || {
let client_key: tfhe::ClientKey = bincode::deserialize(&priv_key).unwrap();
let client_key: tfhe::ClientKey =
bincode::deserialize(&keys.cks_key.clone().unwrap()).unwrap();
let sks: tfhe::ServerKey = bincode::deserialize(&keys.sks_key).unwrap();
tfhe::set_server_key(sks);

let mut decrypted: Vec<(Vec<u8>, DecryptionResult)> = Vec::with_capacity(cts.len());
for ct in cts {
let deserialized =
deserialize_fhe_ciphertext(ct.ciphertext_type, &ct.ciphertext).unwrap();
SupportedFheCiphertexts::decompress(ct.ciphertext_type, &ct.ciphertext).unwrap();
decrypted.push((
ct.handle,
DecryptionResult {
Expand Down
41 changes: 17 additions & 24 deletions fhevm-engine/coprocessor/src/tfhe_worker.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::utils::set_server_key_if_not_set;
use crate::{db_queries::populate_cache_with_tenant_keys, types::TfheTenantKeys};
use fhevm_engine_common::types::SupportedFheCiphertexts;
use fhevm_engine_common::{
tfhe_ops::{current_ciphertext_version, deserialize_fhe_ciphertext, perform_fhe_operation},
tfhe_ops::{current_ciphertext_version, perform_fhe_operation},
types::SupportedFheOperations,
};
use sqlx::{postgres::PgListener, query, Acquire};
use std::{
cell::Cell,
collections::{BTreeSet, HashMap},
num::NonZeroUsize,
};
Expand Down Expand Up @@ -156,7 +156,8 @@ async fn tfhe_worker_cycle(
let mut work_ciphertexts: Vec<(i16, Vec<u8>)> =
Vec::with_capacity(w.dependencies.len());
for (idx, dh) in w.dependencies.iter().enumerate() {
let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
let is_operand_scalar =
w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
if is_operand_scalar {
work_ciphertexts.push((-1, dh.clone()));
} else {
Expand All @@ -171,26 +172,20 @@ async fn tfhe_worker_cycle(
// copy for setting error in database
tfhe_work_set.spawn_blocking(
move || -> Result<_, (Box<(dyn std::error::Error + Send + Sync)>, i32, Vec<u8>)> {
thread_local! {
static TFHE_TENANT_ID: Cell<i32> = Cell::new(-1);
}

// set thread tenant key
// set the server key if not set
{
let mut rk = tenant_key_cache.blocking_write();
let keys = rk
.get(&w.tenant_id)
.expect("Can't get tenant key from cache");
if w.tenant_id != TFHE_TENANT_ID.get() {
tfhe::set_server_key(keys.sks.clone());
TFHE_TENANT_ID.set(w.tenant_id);
}
set_server_key_if_not_set(w.tenant_id, &keys.sks);
}

let mut deserialized_cts: Vec<SupportedFheCiphertexts> =
Vec::with_capacity(work_ciphertexts.len());
for (idx, (ct_type, ct_bytes)) in work_ciphertexts.iter().enumerate() {
let is_operand_scalar = w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
let is_operand_scalar =
w.is_scalar && idx == 1 || fhe_op.does_have_more_than_one_scalar();
if is_operand_scalar {
let mut the_int = tfhe::integer::U256::default();
assert!(
Expand All @@ -208,24 +203,22 @@ async fn tfhe_worker_cycle(
deserialized_cts.push(SupportedFheCiphertexts::Scalar(the_int));
} else {
deserialized_cts.push(
deserialize_fhe_ciphertext(*ct_type, ct_bytes.as_slice()).map_err(
|e| {
SupportedFheCiphertexts::decompress(*ct_type, ct_bytes.as_slice())
.map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> =
Box::new(e);
e.into();
(err, w.tenant_id, w.output_handle.clone())
},
)?,
})?,
);
}
}

let res =
perform_fhe_operation(w.fhe_operation, &deserialized_cts)
.map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
(err, w.tenant_id, w.output_handle.clone())
})?;
let (db_type, db_bytes) = res.serialize();
perform_fhe_operation(w.fhe_operation, &deserialized_cts).map_err(|e| {
let err: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
(err, w.tenant_id, w.output_handle.clone())
})?;
let (db_type, db_bytes) = res.compress();

Ok((w, db_type, db_bytes))
},
Expand Down
16 changes: 15 additions & 1 deletion fhevm-engine/coprocessor/src/utils.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
use std::collections::{BTreeSet, HashMap, HashSet};
use std::{
cell::Cell,
collections::{BTreeSet, HashMap, HashSet},
};

use fhevm_engine_common::types::{FhevmError, SupportedFheOperations};

Expand Down Expand Up @@ -321,3 +324,14 @@ fn test_multi_level_circular_dependency_detection() {
}
}
}

pub fn set_server_key_if_not_set(tenant_id: i32, sks: &tfhe::ServerKey) {
thread_local! {
static TFHE_TENANT_ID: Cell<i32> = Cell::new(-1);
}

if tenant_id != TFHE_TENANT_ID.get() {
tfhe::set_server_key(sks.clone());
TFHE_TENANT_ID.set(tenant_id);
}
}
25 changes: 10 additions & 15 deletions fhevm-engine/executor/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ pub mod executor {

pub fn start(args: &crate::cli::Args) -> Result<()> {
let keys: Arc<FhevmKeys> = Arc::new(SerializedFhevmKeys::load_from_disk().into());
let executor = FhevmExecutorService::new(keys.clone());
let executor = FhevmExecutorService::new();
let runtime = tokio::runtime::Builder::new_multi_thread()
.worker_threads(args.tokio_threads)
.max_blocking_threads(args.fhe_compute_threads)
Expand Down Expand Up @@ -70,24 +70,20 @@ pub struct ComputationState {
pub ciphertexts: HashMap<Handle, InMemoryCiphertext>,
}

struct FhevmExecutorService {
keys: Arc<FhevmKeys>,
}
struct FhevmExecutorService {}

#[tonic::async_trait]
impl FhevmExecutor for FhevmExecutorService {
async fn sync_compute(
&self,
req: Request<SyncComputeRequest>,
) -> Result<Response<SyncComputeResponse>, Status> {
let keys = self.keys.clone();
let resp = spawn_blocking(move || {
let req = req.get_ref();
let mut state = ComputationState::default();

// Exapnd compact ciphertext lists for the whole request.
if Self::expand_compact_lists(&req.compact_ciphertext_lists, &keys, &mut state).is_err()
{
if Self::expand_compact_lists(&req.compact_ciphertext_lists, &mut state).is_err() {
return SyncComputeResponse {
resp: Some(Resp::Error(SyncComputeError::BadInputList.into())),
};
Expand Down Expand Up @@ -138,8 +134,8 @@ impl FhevmExecutor for FhevmExecutorService {
}

impl FhevmExecutorService {
fn new(keys: Arc<FhevmKeys>) -> Self {
FhevmExecutorService { keys }
fn new() -> Self {
FhevmExecutorService {}
}

#[allow(dead_code)]
Expand All @@ -166,11 +162,10 @@ impl FhevmExecutorService {

fn expand_compact_lists(
lists: &Vec<Vec<u8>>,
keys: &FhevmKeys,
state: &mut ComputationState,
) -> Result<(), FhevmError> {
for list in lists {
let cts = try_expand_ciphertext_list(&list, &keys.server_key)?;
let cts = try_expand_ciphertext_list(&list)?;
let list_hash: Handle = Keccak256::digest(list).to_vec();
for (i, ct) in cts.iter().enumerate() {
let mut handle = list_hash.clone();
Expand All @@ -181,7 +176,7 @@ impl FhevmExecutorService {
handle,
InMemoryCiphertext {
expanded: ct.clone(),
compressed: ct.clone().compress(),
compressed: ct.clone().compress().1,
},
);
}
Expand Down Expand Up @@ -268,7 +263,7 @@ impl FhevmExecutorService {
match inputs {
Ok(inputs) => match perform_fhe_operation(comp.operation as i16, &inputs) {
Ok(result) => {
let compressed = result.clone().compress();
let (_, compressed) = result.clone().compress();
state.ciphertexts.insert(
result_handle.clone(),
InMemoryCiphertext {
Expand Down Expand Up @@ -299,15 +294,15 @@ pub fn run_computation(
Ok(FheOperation::FheGetCiphertext) => {
let res = InMemoryCiphertext {
expanded: inputs[0].clone(),
compressed: inputs[0].clone().compress(),
compressed: inputs[0].clone().compress().1,
};
Ok((graph_node_index, res))
}
Ok(_) => match perform_fhe_operation(operation as i16, &inputs) {
Ok(result) => {
let res = InMemoryCiphertext {
expanded: result.clone(),
compressed: result.compress(),
compressed: result.compress().1,
};
Ok((graph_node_index, res))
}
Expand Down
Loading