Skip to content

Commit 5271894

Browse files
sipapeterdettman
authored andcommitted
Update safegcd writeup to reflect the code
1 parent f027e1f commit 5271894

File tree

1 file changed

+85
-63
lines changed

1 file changed

+85
-63
lines changed

doc/safegcd_implementation.md

Lines changed: 85 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -155,14 +155,14 @@ do one division by *2<sup>N</sup>* as a final step:
155155
```python
156156
def divsteps_n_matrix(delta, f, g):
157157
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
158-
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
158+
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N)
159159
for _ in range(N):
160160
if delta > 0 and g & 1:
161-
delta, f, g, u, v, q, r = 1 - delta, g, (g - f) // 2, 2*q, 2*r, q-u, r-v
161+
delta, f, g, u, v, q, r = 1 - delta, g, (g-f)//2, q, r, (q-u)//2, (r-v)//2
162162
elif g & 1:
163-
delta, f, g, u, v, q, r = 1 + delta, f, (g + f) // 2, 2*u, 2*v, q+u, r+v
163+
delta, f, g, u, v, q, r = 1 + delta, f, (g+f)//2, u, v, (q+u)//2, (r+v)//2
164164
else:
165-
delta, f, g, u, v, q, r = 1 + delta, f, (g ) // 2, 2*u, 2*v, q , r
165+
delta, f, g, u, v, q, r = 1 + delta, f, (g )//2, u, v, (q )//2, (r )//2
166166
return delta, (u, v, q, r)
167167
```
168168

@@ -414,9 +414,9 @@ operations (and hope the C compiler isn't smart enough to turn them back into br
414414
divstep can be written instead as (compare to the inner loop of `gcd` in section 1).
415415

416416
```python
417-
x = -f if delta > 0 else f # set x equal to (input) -f or f
417+
x = f if delta > 0 else -f # set x equal to (input) f or -f
418418
if g & 1:
419-
g += x # set g to (input) g-f or g+f
419+
g -= x # set g to (input) g-f or g+f
420420
if delta > 0:
421421
delta = -delta
422422
f += g # set f to (input) g (note that g was set to g-f before)
@@ -433,19 +433,21 @@ that *-v == (v ^ -1) - (-1)*. Thus, if we have a variable *c* that takes on valu
433433
Using this we can write:
434434

435435
```python
436-
x = -f if delta > 0 else f
436+
x = f if delta > 0 else -f
437437
```
438438

439439
in constant-time form as:
440440

441441
```python
442-
c1 = (-delta) >> 63
442+
# Compute c1=0 if delta>0 and c1=-1 if delta<=0.
443+
c1 = (delta - 1) >> 63
443444
# Conditionally negate f based on c1:
444445
x = (f ^ c1) - c1
445446
```
446447

447-
To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;>0* to *-1*
448-
(if true) or *0* (if false). We compute *c1* using right shifting, which is equivalent to dividing by
448+
To use that trick, we need a helper mask variable *c1* that resolves the condition *&delta;&leq;0* to *-1*
449+
(if true) or *0* (if false). We compute *c1* by first subtracting *1*, which results in a negative value
450+
if and only if *&delta;&leq;0*. That is then right shifted, which is equivalent to dividing by
449451
the specified power of *2* and rounding down (in Python, and also in C under the assumption of a typical two's complement system; see
450452
`assumptions.h` for tests that this is the case). Right shifting by *63* thus maps all
451453
numbers in range *[-2<sup>63</sup>,0)* to *-1*, and numbers in range *[0,2<sup>63</sup>)* to *0*.
@@ -454,7 +456,7 @@ Using the facts that *x&0=0* and *x&(-1)=x* (on two's complement systems again),
454456

455457
```python
456458
if g & 1:
457-
g += x
459+
g -= x
458460
```
459461

460462
as:
@@ -463,7 +465,7 @@ as:
463465
# Compute c2=0 if g is even and c2=-1 if g is odd.
464466
c2 = -(g & 1)
465467
# This masks out x if g is even, and leaves x be if g is odd.
466-
g += x & c2
468+
g -= x & c2
467469
```
468470

469471
Using the conditional negation trick again we can write:
@@ -478,7 +480,7 @@ as:
478480

479481
```python
480482
# Compute c3=-1 if g is odd and delta>0, and 0 otherwise.
481-
c3 = c1 & c2
483+
c3 = ~c1 & c2
482484
# Conditionally negate delta based on c3:
483485
delta = (delta ^ c3) - c3
484486
```
@@ -497,45 +499,61 @@ becomes:
497499
f += g & c3
498500
```
499501

500-
It turns out that this can be implemented more efficiently by applying the substitution
501-
*&eta;=-&delta;*. In this representation, negating *&delta;* corresponds to negating *&eta;*, and incrementing
502-
*&delta;* corresponds to decrementing *&eta;*. This allows us to remove the negation in the *c1*
503-
computation:
502+
Putting everything together, extending all operations on f,g (with helper x) to also be applied
503+
to u,q (with helper y) and v,r (with helper z), gives:
504504

505505
```python
506-
# Compute a mask c1 for eta < 0, and compute the conditional negation x of f:
507-
c1 = eta >> 63
508-
x = (f ^ c1) - c1
509-
# Compute a mask c2 for odd g, and conditionally add x to g:
510-
c2 = -(g & 1)
511-
g += x & c2
512-
# Compute a mask c for (eta < 0) and odd (input) g, and use it to conditionally negate eta,
513-
# and add g to f:
514-
c3 = c1 & c2
515-
eta = (eta ^ c3) - c3
516-
f += g & c3
517-
# Incrementing delta corresponds to decrementing eta.
518-
eta -= 1
519-
g >>= 1
506+
def divsteps_n_matrix(delta, f, g):
507+
"""Compute delta and transition matrix t after N divsteps (multiplied by 2^N)."""
508+
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N).
509+
for i in range(N):
510+
c1 = (delta - 1) >> 63
511+
# Compute x, y, z as conditionally-negated versions of f, u, v.
512+
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
513+
c2 = -(g & 1)
514+
# Conditionally subtract x, y, z from g, q, r.
515+
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
516+
c3 = ~c1 & c2
517+
# Conditionally negate delta, and then increment it by 1.
518+
delta = (delta ^ c3) - c3 + 1
519+
# Conditionally add g, q, r to f, u, v.
520+
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
521+
# Shift down g, q, r.
522+
g, q, r = g >> 1, u >> 1, v >> 1
523+
return delta, (u, v, q, r)
520524
```
521525

522-
A variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
526+
An interesting optimization is possible here. If we were to drop the *-c1* in the computation
527+
of *x*, *y*, and *z*, we are making them at worst *1* less than the correct value. That
528+
translates to *g*, *q*, and *r* further being at worst *1* more than the correct value.
529+
Now observe that at the start of every iteration of the loop, *u*, *v*, *q*, and *r* are
530+
all multiples of *2<sup>N-i</sub>*, with *i* the iteration number, and thus all even.
531+
In other words, this potential off by one in *g*, *q*, and *r* only affects their bottommost
532+
bit, which is shifted away at the end of the loop. Thus we can instead write:
533+
534+
```python
535+
# Compute x, y, z as conditionally complemented versions of f, u, v.
536+
x, y, z = f ^ c1, u ^ c1, v ^ c1
537+
```
538+
539+
Finally, a variant of divsteps with better worst-case performance can be used instead: starting *&delta;* at
523540
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
524-
(which can be shown using convex hull analysis). In this case, the substitution *&zeta;=-(&delta;+1/2)*
525-
is used instead to keep the variable integral. Incrementing *&delta;* by *1* still translates to
526-
decrementing *&zeta;* by *1*, but negating *&delta;* now corresponds to going from *&zeta;* to *-(&zeta;+1)*, or
527-
*~&zeta;*. Doing that conditionally based on *c3* is simply:
541+
(which can be shown using [convex hull analysis](https://github.com/sipa/safegcd-bounds)).
542+
In this case, the substitution *&theta;=&delta;-1/2* is used to keep the variable integral.
543+
*&delta;&leq;0* then translates to *&theta;&leq;-1/2*, or because *&theta;* is integral, *&theta;<0*.
544+
Thus instead of `c1 = (delta - 1) >> 63` we get `c1 = theta >> 63`.
545+
Negating *&delta;* now corresponds to going from *&theta;* to
546+
*-&theta;-1*. Doing that conditionally based on *c3* (and then incrementing by one) gives us:
528547

529548
```python
530549
...
531-
c3 = c1 & c2
532-
zeta ^= c3
550+
theta = (theta ^ c3) + 1
533551
...
534552
```
535553

536554
By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
537555
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
538-
`divsteps_n_matrix` is obtained. The full code will be in section 7.
556+
`divsteps_n_matrix` is obtained. The resulting code will be in section 7.
539557

540558
These bit fiddling tricks can also be used to make the conditional negations and additions in
541559
`update_de` and `normalize` constant-time.
@@ -550,7 +568,7 @@ faster non-constant time `divsteps_n_matrix` function.
550568

551569
To do so, first consider yet another way of writing the inner loop of divstep operations in
552570
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
553-
the original version with initial *&delta;=1* and *&eta;=-&delta;* here.
571+
the original version with initial *&delta;=1*, but make the substitution *&eta;=-&delta;*.
554572

555573
```python
556574
for _ in range(N):
@@ -651,37 +669,41 @@ Here we need the negated modular inverse, which is a simple transformation of th
651669
have this 6-bit function (based on the 3-bit function above):
652670
- *f(f<sup>2</sup> - 2)*
653671

654-
This loop, again extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
655-
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version.
672+
This loop, extended to also handle *u*, *v*, *q*, and *r* alongside *f* and *g*, placed in
673+
`divsteps_n_matrix`, gives a significantly faster, but non-constant time version. In order to
674+
avoid intermediary values that need more than N+1 bits, it is possible to instead start
675+
*u* and *v* at *1* instead of at *2<sup>N</sup>*, and then shift up *u* and *v* whenever
676+
*g* is shifted down (instead of shifting down *q* and *r*). This is effectively making the
677+
algorithm operate on *i*-bits downshifted versions of all these variables. The resulting
678+
code is shown in the next section.
656679

657680

658681
## 7. Final Python version
659682

660683
All together we need the following functions:
661684

662685
- A way to compute the transition matrix in constant time, using the `divsteps_n_matrix` function
663-
from section 2, but with its loop replaced by a variant of the constant-time divstep from
664-
section 5, extended to handle *u*, *v*, *q*, *r*:
686+
from section 5, modified to operate on *&theta;* instead of *&delta;*:
665687

666688
```python
667-
def divsteps_n_matrix(zeta, f, g):
668-
"""Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
669-
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
689+
def divsteps_n_matrix(theta, f, g):
690+
"""Compute theta and transition matrix t after N divsteps (multiplied by 2^N)."""
691+
u, v, q, r = 1<<N, 0, 0, 1<<N # start with identity matrix (scaled by 2^N).
670692
for _ in range(N):
671-
c1 = zeta >> 63
672-
# Compute x, y, z as conditionally-negated versions of f, u, v.
673-
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
693+
c1 = theta >> 63
694+
# Compute x, y, z as conditionally complemented versions of f, u, v.
695+
x, y, z = f ^ c1, u ^ c1, v ^ c1
674696
c2 = -(g & 1)
675-
# Conditionally add x, y, z to g, q, r.
676-
g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
677-
c1 &= c2 # reusing c1 here for the earlier c3 variable
678-
zeta = (zeta ^ c1) - 1 # inlining the unconditional zeta decrement here
697+
# Conditionally subtract x, y, z from g, q, r.
698+
g, q, r = g - (x & c2), q - (y & c2), r - (z & c2)
699+
c3 = ~c1 & c2
700+
# Conditionally completement theta, and then increment it by 1.
701+
theta = (theta ^ c3) + 1
679702
# Conditionally add g, q, r to f, u, v.
680-
f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
681-
# When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
682-
# by 2^N. Instead, shift f's coefficients u and v up.
683-
g, u, v = g >> 1, u << 1, v << 1
684-
return zeta, (u, v, q, r)
703+
f, u, v = f + (g & c3), u + (q & c3), v + (r & c3)
704+
# Shift down f, q, r.
705+
g, q, r = g >> 1, u >> 1, v >> 1
706+
return theta, (u, v, q, r)
685707
```
686708

687709
- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
@@ -723,15 +745,15 @@ def normalize(sign, v, M):
723745
return v
724746
```
725747

726-
- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
748+
- And finally the `modinv` function too, adapted to use *&theta;* instead of *&delta;*, and using the fixed
727749
iteration count from section 5:
728750

729751
```python
730752
def modinv(M, Mi, x):
731753
"""Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
732-
zeta, f, g, d, e = -1, M, x, 0, 1
754+
theta, f, g, d, e = 0, M, x, 0, 1
733755
for _ in range((590 + N - 1) // N):
734-
zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
756+
theta, t = divsteps_n_matrix(theta, f % 2**N, g % 2**N)
735757
f, g = update_fg(f, g, t)
736758
d, e = update_de(d, e, t, M, Mi)
737759
return normalize(f, d, M)
@@ -745,7 +767,7 @@ def modinv(M, Mi, x):
745767
NEGINV16 = [15, 5, 3, 9, 7, 13, 11, 1] # NEGINV16[n//2] = (-n)^-1 mod 16, for odd n
746768
def divsteps_n_matrix_var(eta, f, g):
747769
"""Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
748-
u, v, q, r = 1, 0, 0, 1
770+
u, v, q, r = 1, 0, 0, 1 # Start with identity matrix (not scaled; shift during run instead).
749771
i = N
750772
while True:
751773
zeros = min(i, count_trailing_zeros(g))

0 commit comments

Comments
 (0)