Implement the modular inverse using unsigned 256-bit integers addition and shifts #1073

Closed
wants to merge 13 commits
4 changes: 4 additions & 0 deletions src/modinv64.h
@@ -17,6 +17,8 @@
#error "modinv64 requires 128-bit wide multiplication support"
#endif

#include "scalar.h"

/* A signed 62-bit limb representation of integers.
*
* Its value is sum(v[i] * 2^(62*i), i=0..4). */
@@ -43,4 +45,6 @@ static void secp256k1_modinv64_var(secp256k1_modinv64_signed62 *x, const secp256
/* Same as secp256k1_modinv64_var, but constant time in x (not in the modulus). */
static void secp256k1_modinv64(secp256k1_modinv64_signed62 *x, const secp256k1_modinv64_modinfo *modinfo);

static void secp256k1_modinv64_scalar(secp256k1_scalar *ret, const secp256k1_scalar *x, const secp256k1_scalar *m);

#endif /* SECP256K1_MODINV64_H */
143 changes: 143 additions & 0 deletions src/modinv64_impl.h
@@ -590,4 +590,147 @@ static void secp256k1_modinv64_var(secp256k1_modinv64_signed62 *x, const secp256
*x = d;
}

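/* Shift x right by an arbitrary bit count by chaining secp256k1_scalar_shr_int calls, which accept at most 15 bits each. */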
static void _secp256k1_scalar_shr_void(secp256k1_scalar *x, int bits) {
while (bits >= 15) {
secp256k1_scalar_shr_int(x, 15);
bits -= 15;
}
if (bits) {
secp256k1_scalar_shr_int(x, bits);
}
}

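/* Shift x left by an arbitrary bit count by chaining secp256k1_scalar_shl_int calls, which accept at most 15 bits each. */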
static void _secp256k1_scalar_shl_void(secp256k1_scalar *x, int bits) {
while (bits >= 15) {
secp256k1_scalar_shl_int(x, 15);
bits -= 15;
}
if (bits) {
secp256k1_scalar_shl_int(x, bits);
}
}

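/* Bit length of x viewed as a signed 256-bit value: if bit 255 is set and forcePositive is 0, the bit length of -x is measured instead; hint bounds the downward search for the most significant bit. */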
static unsigned int _secp256k1_scalar_msb_signed(const secp256k1_scalar *x, unsigned int forcePositive, int hint) {
if (!forcePositive && secp256k1_scalar_get_bit(x, 255)) {
/*
return secp256k1_scalar_msb_neg(x);
*/
secp256k1_scalar a = *x;
secp256k1_scalar_neg(&a, &a);
return secp256k1_scalar_msb_hint(&a, hint);
}
return secp256k1_scalar_msb_hint(x, hint);
}

/*
* The original algorithm (LS3) is taken from this paper:
* https://www.researchgate.net/publication/304417579_Modular_Inverse_Algorithms_Without_Multiplications_for_Cryptographic_Applications
*
* It was adapted by Anton Bukov and Mikhail Melnik to use uint256 integers:
* https://gist.github.com/k06a/b990b7c7dda766d4f661e653d6804a53
*/
static void secp256k1_modinv64_scalar(secp256k1_scalar *ret, const secp256k1_scalar *a, const secp256k1_scalar *m) {
unsigned int i, llu, llv, f, lltmp, firstFlip = 0;
secp256k1_scalar *tmp;
secp256k1_scalar zero = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0);
secp256k1_scalar one = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 1);
secp256k1_scalar _u, _v, _r, _s, _vv, _ss;
secp256k1_scalar *u = &_u, *v = &_v, *r = &_r, *s = &_s, *vv = &_vv, *ss = &_ss;

#ifdef VERIFY
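/* Note: 0x2d = 45 = 0b101101, whose most significant set bit is bit 5 (0-based), so msb() is expected to return 6 in the 1-based convention used here. */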
secp256k1_scalar six = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0x2d);
secp256k1_scalar twelve = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 0, 0, 0x2d*2);
VERIFY_CHECK(secp256k1_scalar_msb(&zero) == 0);
VERIFY_CHECK(secp256k1_scalar_msb(&one) == 1);
VERIFY_CHECK(secp256k1_scalar_msb(&six) == 6);

/* printf("%d\n", (int)sizeof(*m)); */
if (sizeof(*m) == 32) {
secp256k1_scalar huge = SECP256K1_SCALAR_CONST(0, 0, 0, 0, 0, 2, 0, 0);
VERIFY_CHECK(secp256k1_scalar_msb(&huge) == 66);
secp256k1_scalar_neg(&huge, &huge);
VERIFY_CHECK(_secp256k1_scalar_msb_signed(&huge, 0, 256) == 66);
}

secp256k1_scalar_neg(&six, &six);
VERIFY_CHECK(_secp256k1_scalar_msb_signed(&six, 0, 256) == 6);
secp256k1_scalar_neg(&six, &six);

VERIFY_CHECK(secp256k1_scalar_shl_int(&six, 1) == 0);
VERIFY_CHECK(secp256k1_scalar_eq(&six, &twelve));
#endif

/* if (a < m) { u = m; v = a; r = 0; s = 1; } */
/* else { v = m; u = a; s = 0; r = 1; } */
if (secp256k1_scalar_cmp(a, m) < 0) { *u = *m; *v = *a; *r = zero; *s = one; }
else { *v = *m; *u = *a; *s = zero; *r = one; }

llu = _secp256k1_scalar_msb_signed(u, 1, 256);
llv = _secp256k1_scalar_msb_signed(v, 1, 256);
/* while (ll(v) > 1) */
for (i = 0; !secp256k1_scalar_is_around_zero(v); i++) {
f = llu - llv;
/* if (i == 0 || (u >> 255) == (v >> 255)) */
if (((1 - secp256k1_scalar_get_bit(u, 255)) | (i < 1 + firstFlip)) == ((1 - secp256k1_scalar_get_bit(v, 255)) | (i < 2 - firstFlip))) {
/* u = u - (v << f); */
*vv = *v;
_secp256k1_scalar_shl_void(vv, f);
secp256k1_scalar_minus(u, u, vv);

/* r = r - (s << f); */
*ss = *s;
_secp256k1_scalar_shl_void(ss, f);
secp256k1_scalar_minus(r, r, ss);
}
else {
/* u = u + (v << f); */
*vv = *v;
_secp256k1_scalar_shl_void(vv, f);
secp256k1_scalar_plus(u, u, vv);

/* r = r + (s << f); */
*ss = *s;
_secp256k1_scalar_shl_void(ss, f);
secp256k1_scalar_plus(r, r, ss);
}

/* llu = ll(u); */
llu = _secp256k1_scalar_msb_signed(u, 0, llu);
if (llu < llv) {
firstFlip = 1;
/* (u,v,r,s,llu,llv) = (v,u,s,r,llv,llu); */
tmp = u; u = v; v = tmp;
tmp = r; r = s; s = tmp;
lltmp = llu; llu = llv; llv = lltmp;
}
}

/* if (v == 0) { return 0; } */
if (secp256k1_scalar_is_zero(v)) {
*ret = zero;
return;
}

/* if (v >> 255 == 1) { s = 0-s; } */
if (secp256k1_scalar_get_bit(v, 255)) {
secp256k1_scalar_neg(s, s);
}

/* if (s >> 255 == 0 && s > m) { return s - m; } */
if (!secp256k1_scalar_get_bit(s, 255) && secp256k1_scalar_cmp(s, m) > 0) {
secp256k1_scalar_minus(ret, s, m);
return;
}

/* if (s >> 255 == 1) { return s + m; } */
if (secp256k1_scalar_get_bit(s, 255)) {
secp256k1_scalar_plus(ret, s, m);
return;
}

/* return s; */
*ret = *s;
}

#endif /* SECP256K1_MODINV64_IMPL_H */
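For intuition, here is a minimal sketch of the same shift-and-add inversion on plain int64_t with a small modulus, following the pseudocode comments above. It omits the firstFlip adjustment and all 256-bit scalar machinery, and the helper names (bitlen, shl64, modinv_shift) are illustrative only, not part of the library.

#include <stdint.h>
#include <stdio.h>

/* Bit length of |x|: 1-based position of the highest set bit, 0 for x == 0. */
static int bitlen(int64_t x) {
    uint64_t v = (uint64_t)(x < 0 ? -x : x);
    int n = 0;
    while (v) { v >>= 1; n++; }
    return n;
}

/* Left shift through the unsigned type to avoid signed-shift UB in this toy. */
static int64_t shl64(int64_t x, int f) {
    return (int64_t)((uint64_t)x << f);
}

/* Modular inverse of a mod m using only shifts, additions and subtractions.
 * Returns 0 if a is not invertible modulo m. */
static int64_t modinv_shift(int64_t a, int64_t m) {
    int64_t u, v, r, s, t;
    int llu, llv, f, i;
    if (a < m) { u = m; v = a; r = 0; s = 1; }
    else       { v = m; u = a; s = 0; r = 1; }
    llu = bitlen(u); llv = bitlen(v);
    for (i = 0; v < -1 || v > 1; i++) {
        f = llu - llv;
        if (i == 0 || (u < 0) == (v < 0)) {
            u -= shl64(v, f);          /* u = u - (v << f); */
            r -= shl64(s, f);          /* r = r - (s << f); */
        } else {
            u += shl64(v, f);          /* u = u + (v << f); */
            r += shl64(s, f);          /* r = r + (s << f); */
        }
        llu = bitlen(u);
        if (llu < llv) {               /* keep the longer value in u */
            t = u; u = v; v = t;
            t = r; r = s; s = t;
            f = llu; llu = llv; llv = f;
        }
    }
    if (v == 0) return 0;              /* gcd(a, m) != 1 */
    if (v < 0) s = -s;                 /* the gcd ended up as -1 */
    if (s > m) return s - m;           /* bring s into [0, m) */
    if (s < 0) return s + m;
    return s;
}

int main(void) {
    printf("%lld\n", (long long)modinv_shift(3, 7)); /* 5, since 3*5 = 15 = 1 (mod 7) */
    printf("%lld\n", (long long)modinv_shift(4, 7)); /* 2 */
    return 0;
}

The library version performs the same steps on 256-bit scalars (plus the firstFlip tweak in the sign test), which is why the patch adds raw, non-reducing add, subtract, shift, negate and comparison helpers to scalar.h below.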
34 changes: 34 additions & 0 deletions src/scalar.h
@@ -32,6 +32,9 @@ static unsigned int secp256k1_scalar_get_bits(const secp256k1_scalar *a, unsigne
/** Access bits from a scalar. Not constant time. */
static unsigned int secp256k1_scalar_get_bits_var(const secp256k1_scalar *a, unsigned int offset, unsigned int count);

/** Access a single bit from a scalar. */
static unsigned int secp256k1_scalar_get_bit(const secp256k1_scalar *a, unsigned int offset);

/** Set a scalar from a big endian byte array. The scalar will be reduced modulo group order `n`.
* In: bin: pointer to a 32-byte array.
* Out: r: scalar to be set.
@@ -52,6 +55,12 @@ static void secp256k1_scalar_get_b32(unsigned char *bin, const secp256k1_scalar*
/** Add two scalars together (modulo the group order). Returns whether it overflowed. */
static int secp256k1_scalar_add(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Add two scalars together (without reducing modulo the group order). Returns whether the addition overflowed. */
static int secp256k1_scalar_plus(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Subtract one scalar from another (without reducing modulo the group order). Returns whether the subtraction underflowed. */
static int secp256k1_scalar_minus(secp256k1_scalar *r, const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Conditionally add a power of two to a scalar. The result is not allowed to overflow. */
static void secp256k1_scalar_cadd_bit(secp256k1_scalar *r, unsigned int bit, int flag);

@@ -62,6 +71,10 @@ static void secp256k1_scalar_mul(secp256k1_scalar *r, const secp256k1_scalar *a,
* the low bits that were shifted off */
static int secp256k1_scalar_shr_int(secp256k1_scalar *r, int n);

/** Shift a scalar left by some amount strictly between 0 and 16, returning
* the high bits that were shifted off */
static int secp256k1_scalar_shl_int(secp256k1_scalar *r, int n);

/** Compute the inverse of a scalar (modulo the group order). */
static void secp256k1_scalar_inverse(secp256k1_scalar *r, const secp256k1_scalar *a);

@@ -71,6 +84,9 @@ static void secp256k1_scalar_inverse_var(secp256k1_scalar *r, const secp256k1_sc
/** Compute the complement of a scalar (modulo the group order). */
static void secp256k1_scalar_negate(secp256k1_scalar *r, const secp256k1_scalar *a);

/** Compute the complement of a scalar (without reducing modulo the group order). */
static void secp256k1_scalar_neg(secp256k1_scalar *r, const secp256k1_scalar *a);

/** Check whether a scalar equals zero. */
static int secp256k1_scalar_is_zero(const secp256k1_scalar *a);

@@ -90,6 +106,24 @@ static int secp256k1_scalar_cond_negate(secp256k1_scalar *a, int flag);
/** Compare two scalars. */
static int secp256k1_scalar_eq(const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Compare two scalars. Returns <0, 0 or >0 when a is respectively smaller than, equal to or greater than b. */
static int secp256k1_scalar_cmp(const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Compare two scalars. Not constant time. */
static int secp256k1_scalar_cmp_var(const secp256k1_scalar *a, const secp256k1_scalar *b);

/** Find the most significant bit. Returns its 1-based position, or 0 if the scalar is zero. */
static int secp256k1_scalar_msb(const secp256k1_scalar *a);

/** Find the most significant bit, searching downward from bit position hint. */
static int secp256k1_scalar_msb_hint(const secp256k1_scalar *a, int hint);

/** Find the most significant bit, assuming the argument represents a negative value. */
static int secp256k1_scalar_msb_neg(const secp256k1_scalar *a);

/** Check whether a scalar, interpreted as a signed 256-bit value, equals 0, +1 or -1. */
static int secp256k1_scalar_is_around_zero(const secp256k1_scalar *a);

/** Find r1 and r2 such that r1+r2*2^128 = k. */
static void secp256k1_scalar_split_128(secp256k1_scalar *r1, secp256k1_scalar *r2, const secp256k1_scalar *k);
/** Find r1 and r2 such that r1+r2*lambda = k,
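The scalar backend implementations of the new helpers are not included in this excerpt. As an illustration only, a hinted most-significant-bit scan consistent with the declarations above and with the VERIFY checks in modinv64_impl.h (msb(0) == 0, msb(1) == 1, msb(0x2d) == 6) could look like the sketch below, assuming scalar.h from this patch is included; the name secp256k1_scalar_msb_hint_sketch is hypothetical, and the real code in scalar_4x64_impl.h / scalar_8x32_impl.h presumably works limb-by-limb and may differ.

/* Illustrative sketch: return the 1-based position of the highest set bit of a,
 * scanning downward from bit (hint - 1); return 0 if a is zero. Assumes the
 * most significant set bit is at position <= hint. */
static int secp256k1_scalar_msb_hint_sketch(const secp256k1_scalar *a, int hint) {
    int i;
    if (hint > 256) hint = 256;
    for (i = hint; i > 0; i--) {
        if (secp256k1_scalar_get_bit(a, (unsigned int)(i - 1))) {
            return i;
        }
    }
    return 0;
}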