Skip to content

Commit 3b47a8a

Browse files
authored
Unify ConstMontyParams and MontyParams (#873)
Changes `ConstMontyParams` from having associated constants that duplicate `MontyParams` to having an associated `MontyParams` constant. `MontyParams` has `const fn` constructors, which the `const_monty_params!` (and legacy `impl_modulus!`) macros have been changed to use. This gets us a bit closer to the ideal solution which would leverage the `adt_const_params` feature (rust-lang/rust#95174), as that would allow `ConstMontyForm` to be generic around a `MontyParams` constant.
1 parent 73c1067 commit 3b47a8a

File tree

15 files changed

+110
-146
lines changed

15 files changed

+110
-146
lines changed

benches/const_monty.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ fn bench_montgomery_conversion<M: Measurement>(group: &mut BenchmarkGroup<'_, M>
2121
let mut rng = ChaChaRng::from_os_rng();
2222
group.bench_function("ConstMontyForm creation", |b| {
2323
b.iter_batched(
24-
|| U256::random_mod(&mut rng, Modulus::MODULUS.as_nz_ref()),
24+
|| U256::random_mod(&mut rng, Modulus::PARAMS.modulus().as_nz_ref()),
2525
|x| black_box(ConstMontyForm::new(&x)),
2626
BatchSize::SmallInput,
2727
)

src/modular.rs

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -74,15 +74,15 @@ mod tests {
7474
#[test]
7575
fn test_montgomery_params() {
7676
assert_eq!(
77-
Modulus1::ONE,
77+
Modulus1::PARAMS.one,
7878
U256::from_be_hex("1824b159acc5056f998c4fefecbc4ff55884b7fa0003480200000001fffffffe")
7979
);
8080
assert_eq!(
81-
Modulus1::R2,
81+
Modulus1::PARAMS.r2,
8282
U256::from_be_hex("0748d9d99f59ff1105d314967254398f2b6cedcb87925c23c999e990f3f29c6d")
8383
);
8484
assert_eq!(
85-
Modulus1::MOD_NEG_INV,
85+
Modulus1::PARAMS.mod_neg_inv,
8686
U64::from_be_hex("fffffffeffffffff").limbs[0]
8787
);
8888
}
@@ -98,9 +98,9 @@ mod tests {
9898
// Divide the value R by R, which should equal 1
9999
assert_eq!(
100100
montgomery_reduction::<{ Modulus2::LIMBS }>(
101-
&(Modulus2::ONE, Uint::ZERO),
102-
&Modulus2::MODULUS,
103-
Modulus2::MOD_NEG_INV
101+
&(Modulus2::PARAMS.one, Uint::ZERO),
102+
&Modulus2::PARAMS.modulus,
103+
Modulus2::PARAMS.mod_neg_inv
104104
),
105105
Uint::ONE
106106
);
@@ -111,25 +111,25 @@ mod tests {
111111
// Divide the value R^2 by R, which should equal R
112112
assert_eq!(
113113
montgomery_reduction::<{ Modulus2::LIMBS }>(
114-
&(Modulus2::R2, Uint::ZERO),
115-
&Modulus2::MODULUS,
116-
Modulus2::MOD_NEG_INV
114+
&(Modulus2::PARAMS.r2, Uint::ZERO),
115+
&Modulus2::PARAMS.modulus,
116+
Modulus2::PARAMS.mod_neg_inv
117117
),
118-
Modulus2::ONE
118+
Modulus2::PARAMS.one
119119
);
120120
}
121121

122122
#[test]
123123
fn test_reducing_r2_wide() {
124124
// Divide the value ONE^2 by R, which should equal ONE
125-
let (lo, hi) = Modulus2::ONE.square().split();
125+
let (lo, hi) = Modulus2::PARAMS.one.square().split();
126126
assert_eq!(
127127
montgomery_reduction::<{ Modulus2::LIMBS }>(
128128
&(lo, hi),
129-
&Modulus2::MODULUS,
130-
Modulus2::MOD_NEG_INV
129+
&Modulus2::PARAMS.modulus,
130+
Modulus2::PARAMS.mod_neg_inv
131131
),
132-
Modulus2::ONE
132+
Modulus2::PARAMS.one
133133
);
134134
}
135135

@@ -138,12 +138,12 @@ mod tests {
138138
// Reducing xR should return x
139139
let x =
140140
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
141-
let product = x.widening_mul(&Modulus2::ONE);
141+
let product = x.widening_mul(&Modulus2::PARAMS.one);
142142
assert_eq!(
143143
montgomery_reduction::<{ Modulus2::LIMBS }>(
144144
&product,
145-
&Modulus2::MODULUS,
146-
Modulus2::MOD_NEG_INV
145+
&Modulus2::PARAMS.modulus,
146+
Modulus2::PARAMS.mod_neg_inv
147147
),
148148
x
149149
);
@@ -154,20 +154,21 @@ mod tests {
154154
// Reducing xR^2 should return xR
155155
let x =
156156
U256::from_be_hex("44acf6b7e36c1342c2c5897204fe09504e1e2efb1a900377dbc4e7a6a133ec56");
157-
let product = x.widening_mul(&Modulus2::R2);
157+
let product = x.widening_mul(&Modulus2::PARAMS.r2);
158158

159159
// Computing xR mod modulus without Montgomery reduction
160-
let (lo, hi) = x.widening_mul(&Modulus2::ONE);
160+
let (lo, hi) = x.widening_mul(&Modulus2::PARAMS.one);
161161
let c = lo.concat(&hi);
162-
let red = c.rem_vartime(&NonZero::new(Modulus2::MODULUS.0.concat(&U256::ZERO)).unwrap());
162+
let red =
163+
c.rem_vartime(&NonZero::new(Modulus2::PARAMS.modulus.0.concat(&U256::ZERO)).unwrap());
163164
let (lo, hi) = red.split();
164165
assert_eq!(hi, Uint::ZERO);
165166

166167
assert_eq!(
167168
montgomery_reduction::<{ Modulus2::LIMBS }>(
168169
&product,
169-
&Modulus2::MODULUS,
170-
Modulus2::MOD_NEG_INV
170+
&Modulus2::PARAMS.modulus,
171+
Modulus2::PARAMS.mod_neg_inv
171172
),
172173
lo
173174
);

src/modular/boxed_monty_form.rs

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,10 @@ mod neg;
88
mod pow;
99
mod sub;
1010

11-
use super::{ConstMontyParams, Retrieve, div_by_2};
11+
use super::{Retrieve, div_by_2};
1212
use mul::BoxedMontyMultiplier;
1313

14-
use crate::{BoxedUint, Limb, Monty, Odd, Resize, Word};
14+
use crate::{BoxedUint, Limb, Monty, Odd, Resize, Word, modular::MontyParams};
1515
use alloc::sync::Arc;
1616
use subtle::Choice;
1717

@@ -155,23 +155,30 @@ impl BoxedMontyParams {
155155
pub(crate) fn mod_leading_zeros(&self) -> u32 {
156156
self.0.mod_leading_zeros
157157
}
158+
}
158159

159-
/// Create from a set of [`ConstMontyParams`].
160-
pub fn from_const_params<const LIMBS: usize, P: ConstMontyParams<LIMBS>>() -> Self {
160+
impl<const LIMBS: usize> From<&MontyParams<LIMBS>> for BoxedMontyParams {
161+
fn from(params: &MontyParams<LIMBS>) -> Self {
161162
Self(
162163
BoxedMontyParamsInner {
163-
modulus: P::MODULUS.into(),
164-
one: P::ONE.into(),
165-
r2: P::R2.into(),
166-
r3: P::R3.into(),
167-
mod_neg_inv: P::MOD_NEG_INV,
168-
mod_leading_zeros: P::MOD_LEADING_ZEROS,
164+
modulus: params.modulus.into(),
165+
one: params.one.into(),
166+
r2: params.r2.into(),
167+
r3: params.r3.into(),
168+
mod_neg_inv: params.mod_neg_inv,
169+
mod_leading_zeros: params.mod_leading_zeros,
169170
}
170171
.into(),
171172
)
172173
}
173174
}
174175

176+
impl<const LIMBS: usize> From<MontyParams<LIMBS>> for BoxedMontyParams {
177+
fn from(params: MontyParams<LIMBS>) -> Self {
178+
BoxedMontyParams::from(&params)
179+
}
180+
}
181+
175182
/// An integer in Montgomery form represented using heap-allocated limbs.
176183
#[derive(Clone, Debug, Eq, PartialEq)]
177184
pub struct BoxedMontyForm {

src/modular/const_monty_form.rs

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,10 @@ mod pow;
99
mod sub;
1010

1111
use self::invert::ConstMontyFormInverter;
12-
use super::{Retrieve, SafeGcdInverter, div_by_2::div_by_2, reduction::montgomery_reduction};
13-
use crate::{ConstZero, Limb, Odd, PrecomputeInverter, Uint};
12+
use super::{
13+
MontyParams, Retrieve, SafeGcdInverter, div_by_2::div_by_2, reduction::montgomery_reduction,
14+
};
15+
use crate::{ConstZero, Odd, PrecomputeInverter, Uint};
1416
use core::{fmt::Debug, marker::PhantomData};
1517
use subtle::{Choice, ConditionallySelectable, ConstantTimeEq};
1618

@@ -39,19 +41,8 @@ pub trait ConstMontyParams<const LIMBS: usize>:
3941
/// Number of limbs required to encode the Montgomery form
4042
const LIMBS: usize;
4143

42-
/// The constant modulus
43-
const MODULUS: Odd<Uint<LIMBS>>;
44-
/// 1 in Montgomery form
45-
const ONE: Uint<LIMBS>;
46-
/// `R^2 mod MODULUS`, used to move into Montgomery form
47-
const R2: Uint<LIMBS>;
48-
/// `R^3 mod MODULUS`, used to perform a multiplicative inverse
49-
const R3: Uint<LIMBS>;
50-
/// The lowest limbs of -(MODULUS^-1) mod R
51-
// We only need the LSB because during reduction this value is multiplied modulo 2**Limb::BITS.
52-
const MOD_NEG_INV: Limb;
53-
/// Leading zeros in the modulus, used to choose optimized algorithms
54-
const MOD_LEADING_ZEROS: u32;
44+
/// Montgomery parameters constant.
45+
const PARAMS: MontyParams<LIMBS>;
5546

5647
/// Precompute a Bernstein-Yang inverter for this modulus.
5748
///
@@ -67,7 +58,7 @@ pub trait ConstMontyParams<const LIMBS: usize>:
6758
/// An integer in Montgomery form modulo `MOD`, represented using `LIMBS` limbs.
6859
/// The modulus is constant, so it cannot be set at runtime.
6960
///
70-
/// Internally, the value is stored in Montgomery form (multiplied by MOD::ONE) until it is retrieved.
61+
/// Internally, the value is stored in Montgomery form (multiplied by MOD::PARAMS.one) until it is retrieved.
7162
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
7263
pub struct ConstMontyForm<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> {
7364
montgomery_form: Uint<LIMBS>,
@@ -89,16 +80,16 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
8980

9081
/// The representation of 1 mod `MOD`.
9182
pub const ONE: Self = Self {
92-
montgomery_form: MOD::ONE,
83+
montgomery_form: MOD::PARAMS.one,
9384
phantom: PhantomData,
9485
};
9586

9687
/// Internal helper function to convert to Montgomery form;
9788
/// this lets us cleanly wrap the constructors.
9889
const fn from_integer(integer: &Uint<LIMBS>) -> Self {
99-
let product = integer.widening_mul(&MOD::R2);
90+
let product = integer.widening_mul(&MOD::PARAMS.r2);
10091
let montgomery_form =
101-
montgomery_reduction::<LIMBS>(&product, &MOD::MODULUS, MOD::MOD_NEG_INV);
92+
montgomery_reduction::<LIMBS>(&product, &MOD::PARAMS.modulus, MOD::PARAMS.mod_neg_inv);
10293

10394
Self {
10495
montgomery_form,
@@ -115,8 +106,8 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
115106
pub const fn retrieve(&self) -> Uint<LIMBS> {
116107
montgomery_reduction::<LIMBS>(
117108
&(self.montgomery_form, Uint::ZERO),
118-
&MOD::MODULUS,
119-
MOD::MOD_NEG_INV,
109+
&MOD::PARAMS.modulus,
110+
MOD::PARAMS.mod_neg_inv,
120111
)
121112
}
122113

@@ -146,7 +137,7 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
146137
/// Performs division by 2, that is returns `x` such that `x + x = self`.
147138
pub const fn div_by_2(&self) -> Self {
148139
Self {
149-
montgomery_form: div_by_2(&self.montgomery_form, &MOD::MODULUS),
140+
montgomery_form: div_by_2(&self.montgomery_form, &MOD::PARAMS.modulus),
150141
phantom: PhantomData,
151142
}
152143
}
@@ -206,7 +197,7 @@ where
206197
fn try_random<R: TryRngCore + ?Sized>(rng: &mut R) -> Result<Self, R::Error> {
207198
Ok(Self::new(&Uint::try_random_mod(
208199
rng,
209-
MOD::MODULUS.as_nz_ref(),
200+
MOD::PARAMS.modulus.as_nz_ref(),
210201
)?))
211202
}
212203
}
@@ -229,7 +220,7 @@ where
229220
D: Deserializer<'de>,
230221
{
231222
Uint::<LIMBS>::deserialize(deserializer).and_then(|montgomery_form| {
232-
if montgomery_form < MOD::MODULUS.0 {
223+
if montgomery_form < MOD::PARAMS.modulus.0 {
233224
Ok(Self {
234225
montgomery_form,
235226
phantom: PhantomData,

src/modular/const_monty_form/add.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
1111
montgomery_form: add_montgomery_form(
1212
&self.montgomery_form,
1313
&rhs.montgomery_form,
14-
&MOD::MODULUS,
14+
&MOD::PARAMS.modulus,
1515
),
1616
phantom: core::marker::PhantomData,
1717
}
@@ -20,7 +20,7 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
2020
/// Double `self`.
2121
pub const fn double(&self) -> Self {
2222
Self {
23-
montgomery_form: double_montgomery_form(&self.montgomery_form, &MOD::MODULUS),
23+
montgomery_form: double_montgomery_form(&self.montgomery_form, &MOD::PARAMS.modulus),
2424
phantom: core::marker::PhantomData,
2525
}
2626
}

src/modular/const_monty_form/invert.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,10 @@ where
3131
/// If the number was invertible, the second element of the tuple is the truthy value,
3232
/// otherwise it is the falsy value (in which case the first element's value is unspecified).
3333
pub const fn invert(&self) -> ConstCtOption<Self> {
34-
let inverter =
35-
<Odd<Uint<SAT_LIMBS>> as PrecomputeInverter>::Inverter::new(&MOD::MODULUS, &MOD::R2);
34+
let inverter = <Odd<Uint<SAT_LIMBS>> as PrecomputeInverter>::Inverter::new(
35+
&MOD::PARAMS.modulus,
36+
&MOD::PARAMS.r2,
37+
);
3638

3739
let maybe_inverse = inverter.invert(&self.montgomery_form);
3840
let (inverse, inverse_is_some) = maybe_inverse.components_ref();
@@ -67,8 +69,10 @@ where
6769
/// This version is variable-time with respect to the value of `self`, but constant-time with
6870
/// respect to `MOD`.
6971
pub const fn invert_vartime(&self) -> ConstCtOption<Self> {
70-
let inverter =
71-
<Odd<Uint<SAT_LIMBS>> as PrecomputeInverter>::Inverter::new(&MOD::MODULUS, &MOD::R2);
72+
let inverter = <Odd<Uint<SAT_LIMBS>> as PrecomputeInverter>::Inverter::new(
73+
&MOD::PARAMS.modulus,
74+
&MOD::PARAMS.r2,
75+
);
7276

7377
let maybe_inverse = inverter.invert_vartime(&self.montgomery_form);
7478
let (inverse, inverse_is_some) = maybe_inverse.components_ref();
@@ -121,7 +125,7 @@ where
121125
/// Create a new [`ConstMontyFormInverter`] for the given [`ConstMontyParams`].
122126
#[allow(clippy::new_without_default)]
123127
pub const fn new() -> Self {
124-
let inverter = SafeGcdInverter::new(&MOD::MODULUS, &MOD::R2);
128+
let inverter = SafeGcdInverter::new(&MOD::PARAMS.modulus, &MOD::PARAMS.r2);
125129

126130
Self {
127131
inverter,

src/modular/const_monty_form/lincomb.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,11 @@ impl<MOD: ConstMontyParams<LIMBS>, const LIMBS: usize> ConstMontyForm<MOD, LIMBS
1212
/// For a modulus with leading zeros, this method is more efficient than a naive sum of products.
1313
pub const fn lincomb_vartime(products: &[(Self, Self)]) -> Self {
1414
Self {
15-
montgomery_form: lincomb_const_monty_form(products, &MOD::MODULUS, MOD::MOD_NEG_INV),
15+
montgomery_form: lincomb_const_monty_form(
16+
products,
17+
&MOD::PARAMS.modulus,
18+
MOD::PARAMS.mod_neg_inv,
19+
),
1620
phantom: PhantomData,
1721
}
1822
}
@@ -32,7 +36,7 @@ mod tests {
3236
U256,
3337
"7fffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551"
3438
);
35-
let modulus = P::MODULUS.as_nz_ref();
39+
let modulus = P::PARAMS.modulus.as_nz_ref();
3640

3741
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(1);
3842
for n in 0..1000 {

0 commit comments

Comments
 (0)