Skip to content

Commit aae4236

Browse files
committed
network: mock: Add mock that allows unit testing MPC circuits
1 parent 57730c5 commit aae4236

File tree

5 files changed

+184
-20
lines changed

5 files changed

+184
-20
lines changed

src/algebra/authenticated_scalar.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -664,7 +664,7 @@ impl Sub<&AuthenticatedScalarResult> for &AuthenticatedScalarResult {
664664
AuthenticatedScalarResult {
665665
share: &self.share - &rhs.share,
666666
mac: &self.mac - &rhs.mac,
667-
public_modifier: self.public_modifier.clone(),
667+
public_modifier: self.public_modifier.clone() - rhs.public_modifier.clone(),
668668
}
669669
}
670670
}
@@ -1029,3 +1029,23 @@ pub mod test_helpers {
10291029
val.public_modifier = val.fabric().allocate_scalar(new_value)
10301030
}
10311031
}
1032+
1033+
#[cfg(test)]
1034+
mod tests {
1035+
use crate::{algebra::scalar::Scalar, test_helpers::execute_mock_mpc};
1036+
1037+
/// Test a simple `XOR` circuit
1038+
#[tokio::test]
1039+
async fn test_xor_circuit() {
1040+
let (res, _) = execute_mock_mpc(|fabric| async move {
1041+
let a = &fabric.zero_authenticated();
1042+
let b = &fabric.zero_authenticated();
1043+
let res = a + b - Scalar::from(2u64) * a * b;
1044+
1045+
res.open_authenticated().await
1046+
})
1047+
.await;
1048+
1049+
assert_eq!(res.unwrap(), 0.into());
1050+
}
1051+
}

src/beaver.rs

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -56,37 +56,46 @@ pub trait SharedValueSource: Send + Sync {
5656
(a_vals, b_vals, c_vals)
5757
}
5858
}
59-
60-
/// A dummy value source that outputs only ones
61-
/// Used for testing
59+
/// An implementation of a beaver value source that returns
60+
/// beaver triples (0, 0, 0) for party 0 and (1, 1, 1) for party 1
6261
#[cfg(any(feature = "test_helpers", test))]
6362
#[derive(Clone, Debug, Default)]
64-
pub struct DummySharedScalarSource;
63+
pub struct PartyIDBeaverSource {
64+
/// The ID of the local party
65+
party_id: u64,
66+
}
6567

6668
#[cfg(any(feature = "test_helpers", test))]
67-
#[allow(dead_code)]
68-
impl DummySharedScalarSource {
69-
/// Constructor
70-
pub fn new() -> Self {
71-
Self
69+
impl PartyIDBeaverSource {
70+
/// Create a new beaver source given the local party_id
71+
pub fn new(party_id: u64) -> Self {
72+
Self { party_id }
7273
}
7374
}
7475

76+
/// The PartyIDBeaverSource returns beaver triplets split statically between the
77+
/// parties. We assume a = 2, b = 3 ==> c = 6. [a] = (1, 1); [b] = (3, 0) [c] = (2, 4)
7578
#[cfg(any(feature = "test_helpers", test))]
76-
impl SharedValueSource for DummySharedScalarSource {
79+
impl SharedValueSource for PartyIDBeaverSource {
7780
fn next_shared_bit(&mut self) -> Scalar {
78-
Scalar::one()
81+
// Simply output partyID, assume partyID \in {0, 1}
82+
assert!(self.party_id == 0 || self.party_id == 1);
83+
Scalar::from(self.party_id)
7984
}
8085

81-
fn next_shared_value(&mut self) -> Scalar {
82-
Scalar::one()
86+
fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) {
87+
if self.party_id == 0 {
88+
(Scalar::from(1u64), Scalar::from(3u64), Scalar::from(2u64))
89+
} else {
90+
(Scalar::from(1u64), Scalar::from(0u64), Scalar::from(4u64))
91+
}
8392
}
8493

8594
fn next_shared_inverse_pair(&mut self) -> (Scalar, Scalar) {
86-
(Scalar::one(), Scalar::one())
95+
(Scalar::from(self.party_id), Scalar::from(self.party_id))
8796
}
8897

89-
fn next_triplet(&mut self) -> (Scalar, Scalar, Scalar) {
90-
(Scalar::one(), Scalar::one(), Scalar::one())
98+
fn next_shared_value(&mut self) -> Scalar {
99+
Scalar::from(self.party_id)
91100
}
92101
}

src/lib.rs

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,13 +67,57 @@ pub type BeaverSource<S: SharedValueSource> = Rc<RefCell<S>>;
6767

6868
#[cfg(test)]
6969
pub(crate) mod test_helpers {
70-
use crate::{beaver::DummySharedScalarSource, network::NoRecvNetwork, MpcFabric};
70+
use futures::Future;
71+
72+
use crate::{
73+
beaver::PartyIDBeaverSource,
74+
network::{MockNetwork, NoRecvNetwork, UnboundedDuplexStream},
75+
MpcFabric, PARTY0, PARTY1,
76+
};
7177

7278
/// Create a mock fabric
7379
pub fn mock_fabric() -> MpcFabric {
7480
let network = NoRecvNetwork::default();
75-
let beaver_source = DummySharedScalarSource::new();
81+
let beaver_source = PartyIDBeaverSource::default();
7682

7783
MpcFabric::new(network, beaver_source)
7884
}
85+
86+
/// Run a mock MPC connected by a duplex stream as the mock network
87+
///
88+
/// This will spawn two tasks to execute either side of the MPC
89+
///
90+
/// Returns the outputs of both parties
91+
pub async fn execute_mock_mpc<T, S, F>(mut f: F) -> (T, T)
92+
where
93+
T: Send + 'static,
94+
S: Future<Output = T> + Send + 'static,
95+
F: FnMut(MpcFabric) -> S,
96+
{
97+
// Build a duplex stream to broker communication between the two parties
98+
let (party0_stream, party1_stream) = UnboundedDuplexStream::new_duplex_pair();
99+
let party0_fabric = MpcFabric::new(
100+
MockNetwork::new(PARTY0, party0_stream),
101+
PartyIDBeaverSource::new(PARTY0),
102+
);
103+
let party1_fabric = MpcFabric::new(
104+
MockNetwork::new(PARTY1, party1_stream),
105+
PartyIDBeaverSource::new(PARTY1),
106+
);
107+
108+
// Spawn two tasks to execute the MPC
109+
let fabric0 = party0_fabric.clone();
110+
let fabric1 = party1_fabric.clone();
111+
let party0_task = tokio::spawn(f(fabric0));
112+
let party1_task = tokio::spawn(f(fabric1));
113+
114+
let party0_output = party0_task.await.unwrap();
115+
let party1_output = party1_task.await.unwrap();
116+
117+
// Shutdown the fabrics
118+
party0_fabric.shutdown();
119+
party1_fabric.shutdown();
120+
121+
(party0_output, party1_output)
122+
}
79123
}

src/network.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ mod config;
55
mod mock;
66

77
#[cfg(any(feature = "test_helpers", test))]
8-
pub use mock::NoRecvNetwork;
8+
pub use mock::{NoRecvNetwork, UnboundedDuplexStream, MockNetwork};
99

1010
use async_trait::async_trait;
1111
use quinn::{Endpoint, RecvStream, SendStream};

src/network/mock.rs

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
33
use async_trait::async_trait;
44
use futures::future::pending;
5+
use tokio::sync::mpsc::{unbounded_channel, UnboundedReceiver, UnboundedSender};
56

67
use crate::{error::MpcNetworkError, PARTY0};
78

@@ -36,3 +37,93 @@ impl MpcNetwork for NoRecvNetwork {
3637
Ok(())
3738
}
3839
}
40+
41+
/// A dummy MPC network that operates over a duplex channel instead of a network connection/// An unbounded duplex channel used to mock a network connection
42+
pub struct UnboundedDuplexStream {
43+
/// The send side of the stream
44+
send: UnboundedSender<NetworkOutbound>,
45+
/// The receive side of the stream
46+
recv: UnboundedReceiver<NetworkOutbound>,
47+
}
48+
49+
impl UnboundedDuplexStream {
50+
/// Create a new pair of duplex streams
51+
pub fn new_duplex_pair() -> (Self, Self) {
52+
let (send1, recv1) = unbounded_channel();
53+
let (send2, recv2) = unbounded_channel();
54+
55+
(
56+
Self {
57+
send: send1,
58+
recv: recv2,
59+
},
60+
Self {
61+
send: send2,
62+
recv: recv1,
63+
},
64+
)
65+
}
66+
67+
/// Send a message on the stream
68+
pub fn send(&mut self, msg: NetworkOutbound) {
69+
self.send.send(msg).unwrap();
70+
}
71+
72+
/// Recv a message from the stream
73+
pub async fn recv(&mut self) -> NetworkOutbound {
74+
self.recv.recv().await.unwrap()
75+
}
76+
}
77+
78+
/// A dummy network implementation used for unit testing
79+
pub struct MockNetwork {
80+
/// The ID of the local party
81+
party_id: PartyId,
82+
/// The underlying mock network connection
83+
mock_conn: UnboundedDuplexStream,
84+
}
85+
86+
impl MockNetwork {
87+
/// Create a new mock network from one half of a duplex stream
88+
pub fn new(party_id: PartyId, stream: UnboundedDuplexStream) -> Self {
89+
Self {
90+
party_id,
91+
mock_conn: stream,
92+
}
93+
}
94+
}
95+
96+
#[async_trait]
97+
impl MpcNetwork for MockNetwork {
98+
fn party_id(&self) -> PartyId {
99+
self.party_id
100+
}
101+
102+
async fn send_message(&mut self, message: NetworkOutbound) -> Result<(), MpcNetworkError> {
103+
self.mock_conn.send(message);
104+
Ok(())
105+
}
106+
107+
async fn receive_message(&mut self) -> Result<NetworkOutbound, MpcNetworkError> {
108+
let msg = self.mock_conn.recv().await;
109+
Ok(msg)
110+
}
111+
112+
async fn exchange_messages(
113+
&mut self,
114+
message: NetworkOutbound,
115+
) -> Result<NetworkOutbound, MpcNetworkError> {
116+
if self.party_id() == PARTY0 {
117+
self.send_message(message).await?;
118+
self.receive_message().await
119+
} else {
120+
let res = self.receive_message().await?;
121+
self.send_message(message).await?;
122+
Ok(res)
123+
}
124+
}
125+
126+
async fn close(&mut self) -> Result<(), MpcNetworkError> {
127+
Ok(())
128+
}
129+
}

0 commit comments

Comments
 (0)