Skip to content

Commit 0eab8e2

Browse files
authored
hash2curve: move oversized DST requirements to runtime errors (#1901)
This PR introduces two changes: - Remove requirements that were only relevant for oversized `DST`s. Now these requirements are checked on runtime. While this is unfortunate, the currently limitations simply was that usage with regular sized `DST`s incurred limitations that were not necessary. - Change `len_in_bytes` from `NonZero<usize>` to `NonZero<u16>`. This isn't a big improvement because the error is just moved from `expand_msg()` to the various `GroupDigest` methods. Companion PR: RustCrypto/elliptic-curves#1256. --- I know I have been refactoring this API over and over again, but I actually think this is the last of it (apart from #872 with `generic_const_exprs`). But for completions sake I want to mention the following [from the spec](https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-8): > It is possible, however, to entirely avoid this overhead by taking advantage of the fact that Z_pad depends only on H, and not on the arguments to expand_message_xmd. To do so, first precompute and save the internal state of H after ingesting Z_pad. Then, when computing b_0, initialize H using the saved state. Further details are implementation dependent and are beyond the scope of this document. In summary, we could cache this part: ```rust let mut b_0 = HashT::default(); b_0.update(&Array::<u8, HashT::BlockSize>::default()); ``` Doing this requires passing `ExpandMsg` state, which would change the entire API having to add a parameter to every function. However, as the spec mentions, the cost of not caching it is most likely negligible. We will see in the future if this shows up in benchmarks and if it does we can re-evaluate. I don't believe this will be the case though. Alternatively, we could add a trait to `digest` which allows users to construct a hash prefixed with a `BlockSize` full of zeros that has been computed at compile-time. Which would also require no changes to the API except binding to this trait.
1 parent f24c2ae commit 0eab8e2

File tree

5 files changed

+59
-62
lines changed

5 files changed

+59
-62
lines changed

elliptic-curve/src/hash2curve/group_digest.rs

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,10 @@ pub trait GroupDigest: MapToCurve {
2323
/// > hash function is used.
2424
///
2525
/// # Errors
26-
/// See implementors of [`ExpandMsg`] for errors:
27-
/// - [`ExpandMsgXmd`]
28-
/// - [`ExpandMsgXof`]
26+
/// - `len_in_bytes > u16::MAX`
27+
/// - See implementors of [`ExpandMsg`] for additional errors:
28+
/// - [`ExpandMsgXmd`]
29+
/// - [`ExpandMsgXof`]
2930
///
3031
/// `len_in_bytes = <Self::FieldElement as FromOkm>::Length * 2`
3132
///
@@ -53,9 +54,10 @@ pub trait GroupDigest: MapToCurve {
5354
/// > points in this set are more likely to be output than others.
5455
///
5556
/// # Errors
56-
/// See implementors of [`ExpandMsg`] for errors:
57-
/// - [`ExpandMsgXmd`]
58-
/// - [`ExpandMsgXof`]
57+
/// - `len_in_bytes > u16::MAX`
58+
/// - See implementors of [`ExpandMsg`] for additional errors:
59+
/// - [`ExpandMsgXmd`]
60+
/// - [`ExpandMsgXof`]
5961
///
6062
/// `len_in_bytes = <Self::FieldElement as FromOkm>::Length`
6163
///
@@ -76,9 +78,10 @@ pub trait GroupDigest: MapToCurve {
7678
/// and returns a scalar.
7779
///
7880
/// # Errors
79-
/// See implementors of [`ExpandMsg`] for errors:
80-
/// - [`ExpandMsgXmd`]
81-
/// - [`ExpandMsgXof`]
81+
/// - `len_in_bytes > u16::MAX`
82+
/// - See implementors of [`ExpandMsg`] for additional errors:
83+
/// - [`ExpandMsgXmd`]
84+
/// - [`ExpandMsgXof`]
8285
///
8386
/// `len_in_bytes = <Self::Scalar as FromOkm>::Length`
8487
///

elliptic-curve/src/hash2curve/hash2field.rs

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
55
mod expand_msg;
66

7-
use core::num::NonZeroUsize;
7+
use core::num::NonZeroU16;
88

99
pub use expand_msg::{xmd::*, xof::*, *};
1010

@@ -28,9 +28,10 @@ pub trait FromOkm {
2828
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-hash_to_field-implementatio>
2929
///
3030
/// # Errors
31-
/// See implementors of [`ExpandMsg`] for errors:
32-
/// - [`ExpandMsgXmd`]
33-
/// - [`ExpandMsgXof`]
31+
/// - `len_in_bytes > u16::MAX`
32+
/// - See implementors of [`ExpandMsg`] for additional errors:
33+
/// - [`ExpandMsgXmd`]
34+
/// - [`ExpandMsgXof`]
3435
///
3536
/// `len_in_bytes = T::Length * out.len()`
3637
///
@@ -42,9 +43,10 @@ where
4243
E: ExpandMsg<K>,
4344
T: FromOkm + Default,
4445
{
45-
let len_in_bytes = T::Length::to_usize()
46+
let len_in_bytes = T::Length::USIZE
4647
.checked_mul(out.len())
47-
.and_then(NonZeroUsize::new)
48+
.and_then(|len| len.try_into().ok())
49+
.and_then(NonZeroU16::new)
4850
.ok_or(Error)?;
4951
let mut tmp = Array::<u8, <T as FromOkm>::Length>::default();
5052
let mut expander = E::expand_message(data, domain, len_in_bytes)?;

elliptic-curve/src/hash2curve/hash2field/expand_msg.rs

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ use core::num::NonZero;
77

88
use crate::{Error, Result};
99
use digest::{Digest, ExtendableOutput, Update, XofReader};
10-
use hybrid_array::typenum::{IsLess, True, U256};
1110
use hybrid_array::{Array, ArraySize};
1211

1312
/// Salt when the DST is too long
@@ -34,7 +33,7 @@ pub trait ExpandMsg<K> {
3433
fn expand_message<'dst>(
3534
msg: &[&[u8]],
3635
dst: &'dst [&[u8]],
37-
len_in_bytes: NonZero<usize>,
36+
len_in_bytes: NonZero<u16>,
3837
) -> Result<Self::Expander<'dst>>;
3938
}
4039

@@ -50,20 +49,14 @@ pub trait Expander {
5049
///
5150
/// [dst]: https://www.rfc-editor.org/rfc/rfc9380.html#name-using-dsts-longer-than-255-
5251
#[derive(Debug)]
53-
pub(crate) enum Domain<'a, L>
54-
where
55-
L: ArraySize + IsLess<U256, Output = True>,
56-
{
52+
pub(crate) enum Domain<'a, L: ArraySize> {
5753
/// > 255
5854
Hashed(Array<u8, L>),
5955
/// <= 255
6056
Array(&'a [&'a [u8]]),
6157
}
6258

63-
impl<'a, L> Domain<'a, L>
64-
where
65-
L: ArraySize + IsLess<U256, Output = True>,
66-
{
59+
impl<'a, L: ArraySize> Domain<'a, L> {
6760
pub fn xof<X>(dst: &'a [&'a [u8]]) -> Result<Self>
6861
where
6962
X: Default + ExtendableOutput + Update,
@@ -72,6 +65,10 @@ where
7265
if dst.iter().map(|slice| slice.len()).sum::<usize>() == 0 {
7366
Err(Error)
7467
} else if dst.iter().map(|slice| slice.len()).sum::<usize>() > MAX_DST_LEN {
68+
if L::USIZE > u8::MAX.into() {
69+
return Err(Error);
70+
}
71+
7572
let mut data = Array::<u8, L>::default();
7673
let mut hash = X::default();
7774
hash.update(OVERSIZE_DST_SALT);
@@ -96,6 +93,10 @@ where
9693
if dst.iter().map(|slice| slice.len()).sum::<usize>() == 0 {
9794
Err(Error)
9895
} else if dst.iter().map(|slice| slice.len()).sum::<usize>() > MAX_DST_LEN {
96+
if L::USIZE > u8::MAX.into() {
97+
return Err(Error);
98+
}
99+
99100
Ok(Self::Hashed({
100101
let mut hash = X::new();
101102
hash.update(OVERSIZE_DST_SALT);
@@ -124,8 +125,8 @@ where
124125

125126
pub fn len(&self) -> u8 {
126127
match self {
127-
// Can't overflow because it's enforced on a type level.
128-
Self::Hashed(_) => L::to_u8(),
128+
// Can't overflow because it's checked on creation.
129+
Self::Hashed(_) => L::U8,
129130
// Can't overflow because it's checked on creation.
130131
Self::Array(d) => {
131132
u8::try_from(d.iter().map(|d| d.len()).sum::<usize>()).expect("length overflow")

elliptic-curve/src/hash2curve/hash2field/expand_msg/xmd.rs

Lines changed: 15 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use digest::{
88
FixedOutput, HashMarker,
99
array::{
1010
Array,
11-
typenum::{IsGreaterOrEqual, IsLess, IsLessOrEqual, Prod, True, U2, U256, Unsigned},
11+
typenum::{IsGreaterOrEqual, IsLessOrEqual, Prod, True, U2, Unsigned},
1212
},
1313
block_api::BlockSizeUser,
1414
};
@@ -18,22 +18,17 @@ use digest::{
1818
///
1919
/// # Errors
2020
/// - `dst` contains no bytes
21-
/// - `len_in_bytes > u16::MAX`
21+
/// - `dst > 255 && HashT::OutputSize > 255`
2222
/// - `len_in_bytes > 255 * HashT::OutputSize`
2323
#[derive(Debug)]
2424
pub struct ExpandMsgXmd<HashT>(PhantomData<HashT>)
2525
where
2626
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
27-
HashT::OutputSize: IsLess<U256, Output = True>,
2827
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>;
2928

3029
impl<HashT, K> ExpandMsg<K> for ExpandMsgXmd<HashT>
3130
where
3231
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
33-
// If DST is larger than 255 bytes, the length of the computed DST will depend on the output
34-
// size of the hash, which is still not allowed to be larger than 255.
35-
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-6
36-
HashT::OutputSize: IsLess<U256, Output = True>,
3732
// The number of bits output by `HashT` MUST be at most `HashT::BlockSize`.
3833
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-4
3934
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
@@ -47,17 +42,17 @@ where
4742
fn expand_message<'dst>(
4843
msg: &[&[u8]],
4944
dst: &'dst [&[u8]],
50-
len_in_bytes: NonZero<usize>,
45+
len_in_bytes: NonZero<u16>,
5146
) -> Result<Self::Expander<'dst>> {
52-
let len_in_bytes_u16 = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?;
47+
let b_in_bytes = HashT::OutputSize::USIZE;
5348

5449
// `255 * <b_in_bytes>` can not exceed `u16::MAX`
55-
if len_in_bytes_u16 > 255 * HashT::OutputSize::to_u16() {
50+
if usize::from(len_in_bytes.get()) > 255 * b_in_bytes {
5651
return Err(Error);
5752
}
5853

59-
let b_in_bytes = HashT::OutputSize::to_usize();
60-
let ell = u8::try_from(len_in_bytes.get().div_ceil(b_in_bytes)).map_err(|_| Error)?;
54+
let ell = u8::try_from(usize::from(len_in_bytes.get()).div_ceil(b_in_bytes))
55+
.expect("should never pass the previous check");
6156

6257
let domain = Domain::xmd::<HashT>(dst)?;
6358
let mut b_0 = HashT::default();
@@ -67,7 +62,7 @@ where
6762
b_0.update(msg);
6863
}
6964

70-
b_0.update(&len_in_bytes_u16.to_be_bytes());
65+
b_0.update(&len_in_bytes.get().to_be_bytes());
7166
b_0.update(&[0]);
7267
domain.update_hash(&mut b_0);
7368
b_0.update(&[domain.len()]);
@@ -96,7 +91,6 @@ where
9691
pub struct ExpanderXmd<'a, HashT>
9792
where
9893
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
99-
HashT::OutputSize: IsLess<U256, Output = True>,
10094
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
10195
{
10296
b_0: Array<u8, HashT::OutputSize>,
@@ -110,7 +104,6 @@ where
110104
impl<HashT> ExpanderXmd<'_, HashT>
111105
where
112106
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
113-
HashT::OutputSize: IsLess<U256, Output = True>,
114107
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
115108
{
116109
fn next(&mut self) -> bool {
@@ -140,7 +133,6 @@ where
140133
impl<HashT> Expander for ExpanderXmd<'_, HashT>
141134
where
142135
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
143-
HashT::OutputSize: IsLess<U256, Output = True>,
144136
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
145137
{
146138
fn fill_bytes(&mut self, okm: &mut [u8]) {
@@ -157,11 +149,10 @@ where
157149
#[cfg(test)]
158150
mod test {
159151
use super::*;
160-
use core::mem::size_of;
161152
use hex_literal::hex;
162153
use hybrid_array::{
163154
ArraySize,
164-
typenum::{U4, U8, U32, U128},
155+
typenum::{IsLess, U4, U8, U32, U128, U65536},
165156
};
166157
use sha2::Sha256;
167158

@@ -172,9 +163,8 @@ mod test {
172163
bytes: &[u8],
173164
) where
174165
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
175-
HashT::OutputSize: IsLess<U256, Output = True>,
176166
{
177-
let block = HashT::BlockSize::to_usize();
167+
let block = HashT::BlockSize::USIZE;
178168
assert_eq!(
179169
Array::<u8, HashT::BlockSize>::default().as_slice(),
180170
&bytes[..block]
@@ -206,25 +196,24 @@ mod test {
206196

207197
impl TestVector {
208198
#[allow(clippy::panic_in_result_fn)]
209-
fn assert<HashT, L: ArraySize>(
199+
fn assert<HashT, L>(
210200
&self,
211201
dst: &'static [u8],
212202
domain: &Domain<'_, HashT::OutputSize>,
213203
) -> Result<()>
214204
where
215205
HashT: BlockSizeUser + Default + FixedOutput + HashMarker,
216-
HashT::OutputSize: IsLess<U256, Output = True>
217-
+ IsLessOrEqual<HashT::BlockSize, Output = True>
218-
+ Mul<U8>,
206+
HashT::OutputSize: IsLessOrEqual<HashT::BlockSize, Output = True>,
219207
HashT::OutputSize: IsGreaterOrEqual<U8, Output = True>,
208+
L: ArraySize + IsLess<U65536, Output = True>,
220209
{
221-
assert_message::<HashT>(self.msg, domain, L::to_u16(), self.msg_prime);
210+
assert_message::<HashT>(self.msg, domain, L::U16, self.msg_prime);
222211

223212
let dst = [dst];
224213
let mut expander = <ExpandMsgXmd<HashT> as ExpandMsg<U4>>::expand_message(
225214
&[self.msg],
226215
&dst,
227-
NonZero::new(L::to_usize()).ok_or(Error)?,
216+
NonZero::new(L::U16).ok_or(Error)?,
228217
)?;
229218

230219
let mut uniform_bytes = Array::<u8, L>::default();

elliptic-curve/src/hash2curve/hash2field/expand_msg/xof.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,22 @@
11
//! `expand_message_xof` for the `ExpandMsg` trait
22
33
use super::{Domain, ExpandMsg, Expander};
4-
use crate::{Error, Result};
4+
use crate::Result;
55
use core::{fmt, num::NonZero, ops::Mul};
66
use digest::{
77
CollisionResistance, ExtendableOutput, HashMarker, Update, XofReader, typenum::IsGreaterOrEqual,
88
};
99
use hybrid_array::{
1010
ArraySize,
11-
typenum::{IsLess, Prod, True, U2, U256},
11+
typenum::{Prod, True, U2},
1212
};
1313

1414
/// Implements `expand_message_xof` via the [`ExpandMsg`] trait:
1515
/// <https://www.rfc-editor.org/rfc/rfc9380.html#name-expand_message_xof>
1616
///
1717
/// # Errors
1818
/// - `dst` contains no bytes
19-
/// - `len_in_bytes > u16::MAX`
19+
/// - `dst > 255 && K * 2 > 255`
2020
pub struct ExpandMsgXof<HashT>
2121
where
2222
HashT: Default + ExtendableOutput + Update + HashMarker,
@@ -41,7 +41,7 @@ where
4141
HashT: Default + ExtendableOutput + Update + HashMarker,
4242
// If DST is larger than 255 bytes, the length of the computed DST is calculated by `K * 2`.
4343
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.1-2.1
44-
K: Mul<U2, Output: ArraySize + IsLess<U256, Output = True>>,
44+
K: Mul<U2, Output: ArraySize>,
4545
// The collision resistance of `HashT` MUST be at least `K` bits.
4646
// https://www.rfc-editor.org/rfc/rfc9380.html#section-5.3.2-2.1
4747
HashT: CollisionResistance<CollisionResistance: IsGreaterOrEqual<K, Output = True>>,
@@ -51,9 +51,9 @@ where
5151
fn expand_message<'dst>(
5252
msg: &[&[u8]],
5353
dst: &'dst [&[u8]],
54-
len_in_bytes: NonZero<usize>,
54+
len_in_bytes: NonZero<u16>,
5555
) -> Result<Self::Expander<'dst>> {
56-
let len_in_bytes = u16::try_from(len_in_bytes.get()).map_err(|_| Error)?;
56+
let len_in_bytes = len_in_bytes.get();
5757

5858
let domain = Domain::<Prod<K, U2>>::xof::<HashT>(dst)?;
5959
let mut reader = HashT::default();
@@ -81,12 +81,14 @@ where
8181

8282
#[cfg(test)]
8383
mod test {
84+
use crate::Error;
85+
8486
use super::*;
8587
use core::mem::size_of;
8688
use hex_literal::hex;
8789
use hybrid_array::{
8890
Array, ArraySize,
89-
typenum::{U16, U32, U128},
91+
typenum::{IsLess, U16, U32, U128, U65536},
9092
};
9193
use sha3::Shake128;
9294

@@ -124,14 +126,14 @@ mod test {
124126
+ Update
125127
+ HashMarker
126128
+ CollisionResistance<CollisionResistance: IsGreaterOrEqual<U16, Output = True>>,
127-
L: ArraySize,
129+
L: ArraySize + IsLess<U65536, Output = True>,
128130
{
129131
assert_message(self.msg, domain, L::to_u16(), self.msg_prime);
130132

131133
let mut expander = <ExpandMsgXof<HashT> as ExpandMsg<U16>>::expand_message(
132134
&[self.msg],
133135
&[dst],
134-
NonZero::new(L::to_usize()).ok_or(Error)?,
136+
NonZero::new(L::U16).ok_or(Error)?,
135137
)?;
136138

137139
let mut uniform_bytes = Array::<u8, L>::default();

0 commit comments

Comments
 (0)