Skip to content

Commit 57730c5

Browse files
committed
fabric: Add batch sharing methods
1 parent 503e668 commit 57730c5

File tree

5 files changed

+44
-8
lines changed

5 files changed

+44
-8
lines changed

src/algebra/authenticated_scalar.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,10 +117,13 @@ impl AuthenticatedScalarResult {
117117
values: BatchScalarResult,
118118
n: usize,
119119
) -> Vec<AuthenticatedScalarResult> {
120-
// Convert to a set of scalar results, the identity gate does this when set to `n` output arity
120+
// Convert to a set of scalar results
121121
let scalar_results = values
122122
.fabric()
123-
.new_batch_gate_op(vec![values.id()], n, |args| args);
123+
.new_batch_gate_op(vec![values.id()], n, |mut args| {
124+
let scalars: Vec<Scalar> = args.pop().unwrap().into();
125+
scalars.into_iter().map(ResultValue::Scalar).collect()
126+
});
124127

125128
Self::new_shared_batch(&scalar_results)
126129
}

src/algebra/authenticated_stark_point.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,10 +115,13 @@ impl AuthenticatedStarkPointResult {
115115
values: BatchStarkPointResult,
116116
n: usize,
117117
) -> Vec<AuthenticatedStarkPointResult> {
118-
// Convert to a set of scalar results, the identity gate does this when set to `n` output arity
118+
// Convert to a set of scalar results
119119
let scalar_results = values
120120
.fabric()
121-
.new_batch_gate_op(vec![values.id()], n, |args| args);
121+
.new_batch_gate_op(vec![values.id()], n, |mut args| {
122+
let args: Vec<StarkPoint> = args.pop().unwrap().into();
123+
args.into_iter().map(ResultValue::Point).collect_vec()
124+
});
122125

123126
Self::new_shared_batch(&scalar_results)
124127
}

src/algebra/stark_curve.rs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,10 @@ impl StarkPointResult {
529529
.chain(a.iter().flat_map(|a| a.ids()))
530530
.collect_vec();
531531

532-
let results =
533-
fabric.new_batch_gate_op(all_ids, n /* output_arity */, move |mut args| {
532+
let results = fabric.new_batch_gate_op(
533+
all_ids,
534+
AUTHENTICATED_STARK_POINT_RESULT_LEN * n, /* output_arity */
535+
move |mut args| {
534536
let points: Vec<StarkPoint> = args.drain(..n).map(StarkPoint::from).collect_vec();
535537

536538
let mut results = Vec::with_capacity(AUTHENTICATED_STARK_POINT_RESULT_LEN * n);
@@ -549,7 +551,8 @@ impl StarkPointResult {
549551
}
550552

551553
results
552-
});
554+
},
555+
);
553556

554557
AuthenticatedStarkPointResult::from_flattened_iterator(results.into_iter())
555558
}

src/fabric.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1064,4 +1064,30 @@ impl MpcFabric {
10641064
let (left, right) = authenticated_left_right.split_at(n);
10651065
(left.to_vec(), right.to_vec())
10661066
}
1067+
1068+
/// Sample a random shared bit from the beaver source
1069+
pub fn random_shared_bit(&self) -> AuthenticatedScalarResult {
1070+
let bit = self
1071+
.inner
1072+
.beaver_source
1073+
.lock()
1074+
.expect("beaver source poisoned")
1075+
.next_shared_bit();
1076+
1077+
let bit = self.allocate_scalar(bit);
1078+
AuthenticatedScalarResult::new_shared(bit)
1079+
}
1080+
1081+
/// Sample a batch of random shared bits from the beaver source
1082+
pub fn random_shared_bits(&self, n: usize) -> Vec<AuthenticatedScalarResult> {
1083+
let bits = self
1084+
.inner
1085+
.beaver_source
1086+
.lock()
1087+
.expect("beaver source poisoned")
1088+
.next_shared_bit_batch(n);
1089+
1090+
let bits = self.allocate_scalars(bits);
1091+
AuthenticatedScalarResult::new_shared_batch(&bits)
1092+
}
10671093
}

src/fabric/result.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,9 @@ impl From<ResultValue> for NetworkPayload {
6565
match value {
6666
ResultValue::Bytes(bytes) => NetworkPayload::Bytes(bytes),
6767
ResultValue::Scalar(scalar) => NetworkPayload::Scalar(scalar),
68+
ResultValue::ScalarBatch(scalars) => NetworkPayload::ScalarBatch(scalars),
6869
ResultValue::Point(point) => NetworkPayload::Point(point),
69-
_ => panic!("not a valid network payload type"),
70+
ResultValue::PointBatch(points) => NetworkPayload::PointBatch(points),
7071
}
7172
}
7273
}

0 commit comments

Comments
 (0)