Skip to content

Commit 103607c

Browse files
huitseekerkevinlewi
authored andcommitted
Uses conditional compilation to make sure postprocessing is only available in tests
See rust-lang/rust#64010
1 parent 8d29f72 commit 103607c

File tree

3 files changed

+68
-72
lines changed

3 files changed

+68
-72
lines changed

src/opaque.rs

Lines changed: 22 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -561,37 +561,25 @@ impl<CS: CipherSuite> ClientRegistration<CS> {
561561
&Vec::new(),
562562
password,
563563
blinding_factor_rng,
564-
)
565-
}
566-
567-
/// Same as ClientRegistration::start, but also accepts a username and server name as input
568-
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
569-
user_name: &[u8],
570-
server_name: &[u8],
571-
password: &[u8],
572-
blinding_factor_rng: &mut R,
573-
) -> Result<(RegisterFirstMessage<CS::Group>, Self), ProtocolError> {
574-
Self::start_with_user_and_server_name_and_postprocessing(
575-
user_name,
576-
server_name,
577-
password,
578-
blinding_factor_rng,
564+
#[cfg(test)]
579565
std::convert::identity,
580566
)
581567
}
582568

583-
/// Same as ClientRegistration::start, but also accepts a username and server name as input as well as
584-
/// an optional postprocessing function for the blinding factor
585-
pub fn start_with_user_and_server_name_and_postprocessing<R: RngCore + CryptoRng>(
569+
/// Same as ClientRegistration::start, but also accepts a username and
570+
/// server name as input
571+
/// as well as an optional postprocessing function for the blinding factor(used in tests)
572+
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
586573
user_name: &[u8],
587574
server_name: &[u8],
588575
password: &[u8],
589576
blinding_factor_rng: &mut R,
590-
postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
577+
#[cfg(test)] postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
591578
) -> Result<(RegisterFirstMessage<CS::Group>, Self), ProtocolError> {
592-
let (token, alpha) = oprf::blind_with_postprocessing::<R, CS::Group>(
579+
let (token, alpha) = oprf::blind::<R, CS::Group>(
593580
&password,
594581
blinding_factor_rng,
582+
#[cfg(test)]
595583
postprocess,
596584
)?;
597585

@@ -1018,35 +1006,31 @@ impl<CS: CipherSuite> ClientLogin<CS> {
10181006
password: &[u8],
10191007
rng: &mut R,
10201008
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
1021-
Self::start_with_user_and_server_name(&Vec::new(), &Vec::new(), password, rng)
1022-
}
1023-
1024-
/// Same as start, but allows the user to supply a username and server name
1025-
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
1026-
user_name: &[u8],
1027-
server_name: &[u8],
1028-
password: &[u8],
1029-
rng: &mut R,
1030-
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
1031-
Self::start_with_user_and_server_name_and_postprocessing(
1032-
user_name,
1033-
server_name,
1009+
Self::start_with_user_and_server_name(
1010+
&Vec::new(),
1011+
&Vec::new(),
10341012
password,
10351013
rng,
1014+
#[cfg(test)]
10361015
std::convert::identity,
10371016
)
10381017
}
10391018

1040-
/// Same as start, but allows the user to supply a username and server name and postprocessing function
1041-
pub fn start_with_user_and_server_name_and_postprocessing<R: RngCore + CryptoRng>(
1019+
/// Same as start, but allows the user to supply a username and server name
1020+
/// and, in tests, a postprocessing function
1021+
pub fn start_with_user_and_server_name<R: RngCore + CryptoRng>(
10421022
user_name: &[u8],
10431023
server_name: &[u8],
10441024
password: &[u8],
10451025
rng: &mut R,
1046-
postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
1026+
#[cfg(test)] postprocess: fn(<CS::Group as Group>::Scalar) -> <CS::Group as Group>::Scalar,
10471027
) -> Result<(LoginFirstMessage<CS>, Self), ProtocolError> {
1048-
let (token, alpha) =
1049-
oprf::blind_with_postprocessing::<R, CS::Group>(&password, rng, postprocess)?;
1028+
let (token, alpha) = oprf::blind::<R, CS::Group>(
1029+
&password,
1030+
rng,
1031+
#[cfg(test)]
1032+
postprocess,
1033+
)?;
10501034

10511035
let (ke1_state, ke1_message) = CS::KeyExchange::generate_ke1(alpha.to_arr().to_vec(), rng)?;
10521036

src/oprf.rs

Lines changed: 27 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -23,14 +23,18 @@ static STR_VOPRF: &[u8] = b"VOPRF05";
2323
/// message is sent from the client (who holds the input) to the server (who holds the OPRF key).
2424
/// The client can also pass in an optional "pepper" string to be mixed in with the input through
2525
/// an HKDF computation.
26-
pub(crate) fn blind_with_postprocessing<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
26+
pub(crate) fn blind<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
2727
input: &[u8],
2828
blinding_factor_rng: &mut R,
29-
postprocess: fn(G::Scalar) -> G::Scalar,
29+
#[cfg(test)] postprocess: fn(G::Scalar) -> G::Scalar,
3030
) -> Result<(Token<G>, G), InternalPakeError> {
3131
let mapped_point = G::map_to_curve(input, Some(STR_VOPRF)); // TODO: add contextString from RFC
3232
let blinding_factor = G::random_scalar(blinding_factor_rng);
33+
#[cfg(test)]
3334
let blind = postprocess(blinding_factor);
35+
#[cfg(not(test))]
36+
let blind = blinding_factor;
37+
3438
let blind_token = mapped_point * &blind;
3539
Ok((
3640
Token {
@@ -60,23 +64,34 @@ pub(crate) fn unblind_and_finalize<G: Group, H: Hash>(
6064
Ok(prk)
6165
}
6266

63-
// Benchmarking shims
67+
////////////////////////
68+
// Benchmarking shims //
69+
////////////////////////
70+
6471
#[cfg(feature = "bench")]
72+
#[doc(hidden)]
6573
#[inline]
6674
pub fn blind_shim<R: RngCore + CryptoRng, G: GroupWithMapToCurve>(
6775
input: &[u8],
6876
blinding_factor_rng: &mut R,
6977
) -> Result<(Token<G>, G), InternalPakeError> {
70-
blind_with_postprocessing(input, blinding_factor_rng, std::convert::identity)
78+
blind(
79+
input,
80+
blinding_factor_rng,
81+
#[cfg(test)]
82+
std::convert::identity,
83+
)
7184
}
7285

7386
#[cfg(feature = "bench")]
87+
#[doc(hidden)]
7488
#[inline]
7589
pub fn evaluate_shim<G: Group>(point: G, oprf_key: &G::Scalar) -> Result<G, InternalPakeError> {
7690
evaluate(point, oprf_key)
7791
}
7892

7993
#[cfg(feature = "bench")]
94+
#[doc(hidden)]
8095
#[inline]
8196
pub fn unblind_and_finalize_shim<G: Group, H: Hash>(
8297
token: &Token<G>,
@@ -85,8 +100,10 @@ pub fn unblind_and_finalize_shim<G: Group, H: Hash>(
85100
unblind_and_finalize::<G, H>(token, point)
86101
}
87102

88-
// Tests
89-
// =====
103+
///////////
104+
// Tests //
105+
// ===== //
106+
///////////
90107

91108
#[cfg(test)]
92109
mod tests {
@@ -117,11 +134,8 @@ mod tests {
117134
fn oprf_retrieval() -> Result<(), InternalPakeError> {
118135
let input = b"hunter2";
119136
let mut rng = OsRng;
120-
let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>(
121-
&input[..],
122-
&mut rng,
123-
std::convert::identity,
124-
)?;
137+
let (token, alpha) =
138+
blind::<_, RistrettoPoint>(&input[..], &mut rng, std::convert::identity)?;
125139
let oprf_key_bytes = arr![
126140
u8; 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,
127141
24, 25, 26, 27, 28, 29, 30, 31, 32,
@@ -139,12 +153,8 @@ mod tests {
139153
let mut rng = OsRng;
140154
let mut input = vec![0u8; 64];
141155
rng.fill_bytes(&mut input);
142-
let (token, alpha) = blind_with_postprocessing::<_, RistrettoPoint>(
143-
&input,
144-
&mut rng,
145-
std::convert::identity,
146-
)
147-
.unwrap();
156+
let (token, alpha) =
157+
blind::<_, RistrettoPoint>(&input, &mut rng, std::convert::identity).unwrap();
148158
let res = unblind_and_finalize::<RistrettoPoint, sha2::Sha256>(&token, alpha).unwrap();
149159

150160
let (hashed_input, _) = Hkdf::<Sha512>::extract(Some(STR_VOPRF), &input);

src/tests/opaque_ke_test.rs

Lines changed: 19 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,7 @@ where
255255
id_s,
256256
password,
257257
&mut blinding_factor_registration_rng,
258+
std::convert::identity,
258259
)
259260
.unwrap();
260261
let r1_bytes = r1.serialize().to_vec();
@@ -292,6 +293,7 @@ where
292293
id_s,
293294
password,
294295
&mut client_login_start_rng,
296+
std::convert::identity,
295297
)
296298
.unwrap();
297299
let l1_bytes = l1.serialize().to_vec();
@@ -363,14 +365,15 @@ fn postprocess_blinding_factor<G: Group>(_: G::Scalar) -> G::Scalar {
363365
fn test_r1() -> Result<(), PakeError> {
364366
let parameters = populate_test_vectors(&serde_json::from_str(TEST_VECTOR).unwrap());
365367
let mut rng = OsRng;
366-
let (r1, client_registration) = ClientRegistration::<X255193dhNoSlowHash>::start_with_user_and_server_name_and_postprocessing(
367-
&parameters.id_u,
368-
&parameters.id_s,
369-
&parameters.password,
370-
&mut rng,
371-
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
372-
)
373-
.unwrap();
368+
let (r1, client_registration) =
369+
ClientRegistration::<X255193dhNoSlowHash>::start_with_user_and_server_name(
370+
&parameters.id_u,
371+
&parameters.id_s,
372+
&parameters.password,
373+
&mut rng,
374+
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
375+
)
376+
.unwrap();
374377
assert_eq!(hex::encode(&parameters.r1), hex::encode(r1.serialize()));
375378
assert_eq!(
376379
hex::encode(&parameters.client_registration_state),
@@ -453,15 +456,14 @@ fn test_l1() -> Result<(), PakeError> {
453456
]
454457
.concat();
455458
let mut client_login_start_rng = CycleRng::new(client_login_start);
456-
let (l1, client_login) =
457-
ClientLogin::<X255193dhNoSlowHash>::start_with_user_and_server_name_and_postprocessing(
458-
&parameters.id_u,
459-
&parameters.id_s,
460-
&parameters.password,
461-
&mut client_login_start_rng,
462-
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
463-
)
464-
.unwrap();
459+
let (l1, client_login) = ClientLogin::<X255193dhNoSlowHash>::start_with_user_and_server_name(
460+
&parameters.id_u,
461+
&parameters.id_s,
462+
&parameters.password,
463+
&mut client_login_start_rng,
464+
postprocess_blinding_factor::<<X255193dhNoSlowHash as CipherSuite>::Group>,
465+
)
466+
.unwrap();
465467
assert_eq!(hex::encode(&parameters.l1), hex::encode(l1.serialize()));
466468
assert_eq!(
467469
hex::encode(&parameters.client_login_state),

0 commit comments

Comments
 (0)