@@ -1070,6 +1070,7 @@ def pow(self, a, b):
1070
1070
1071
1071
def np_pow (self , a , b ):
1072
1072
"""Secure exponentiation a raised to the power of b, for public integer b."""
1073
+ # TODO: extend to non-scalar b
1073
1074
if b == 254 : # addition chain for AES S-Box (11 multiplications in 9 rounds)
1074
1075
d = a
1075
1076
c = self .np_multiply (d , d )
@@ -1142,7 +1143,7 @@ def is_zero(self, a):
1142
1143
return self .sgn (a , EQ = True )
1143
1144
1144
1145
@mpc_coro
1145
- async def _is_zero (self , a ):
1146
+ async def _is_zero (self , a ): # a la [NO07]
1146
1147
"""Probabilistic zero test."""
1147
1148
stype = type (a )
1148
1149
await self .returnType ((stype , True ))
@@ -1156,9 +1157,9 @@ async def _is_zero(self, a):
1156
1157
a = a .value
1157
1158
r = self ._randoms (Zp , k )
1158
1159
c = [Zp (a * r [i ].value + (1 - (z [i ].value << 1 )) * u2 [i ].value ) for i in range (k )]
1159
- # -1 is nonsquare for Blum p, u_i !=0 w.v.h.p.
1160
- # If a == 0, c_i is square mod p iff z[i] == 0.
1161
- # If a != 0, c_i is square mod p independent of z[i].
1160
+ # -1 is nonsquare for Blum p, u[i] !=0 w.v.h.p.
1161
+ # If a == 0, c[i] is square mod p iff z[i] == 0.
1162
+ # If a != 0, c[i] is square mod p independent of z[i].
1162
1163
c = await self .output (c , threshold = 2 * self .threshold )
1163
1164
for i in range (k ):
1164
1165
if c [i ] == 0 :
@@ -2508,10 +2509,50 @@ def np_less(self, a, b):
2508
2509
return self .np_sgn (a - b , LT = True )
2509
2510
2510
2511
def np_equal (self , a , b ):
2511
- """Secure comparison a < b."""
2512
- return self .np_sgn (a - b , EQ = True )
2513
- # TODO: use prob. zerotest as well like in is_zero
2514
- # TODO: cover finite field arrays (introduce np_pow())
2512
+ """Secure comparison a == b."""
2513
+ d = a - b
2514
+ stype = d .sectype
2515
+ if issubclass (stype , self .SecureFiniteField ):
2516
+ return 1 - self .np_pow (d , stype .field .order - 1 )
2517
+
2518
+ if stype .bit_length / 2 > self .options .sec_param >= 8 and stype .field .order % 4 == 3 :
2519
+ return self ._np_is_zero (d )
2520
+
2521
+ return self .np_sgn (d , EQ = True )
2522
+
2523
+ @mpc_coro
2524
+ async def _np_is_zero (self , a ):
2525
+ """Probabilistic zero test, elementwise."""
2526
+ stype = type (a )
2527
+ shape = a .shape
2528
+ await self .returnType ((stype , True , shape ))
2529
+ Zp = stype .sectype .field
2530
+
2531
+ n = a .size
2532
+ k = self .options .sec_param
2533
+ z = self .np_random_bits (Zp , k * n )
2534
+ r = self ._np_randoms (Zp , k * n )
2535
+ u2 = self ._reshare (r * r )
2536
+ r = self ._np_randoms (Zp , k * n )
2537
+ a , u2 , z = await self .gather (a , u2 , z )
2538
+ a = a .value .reshape ((n ,))
2539
+ r = r .value .reshape ((k , n ))
2540
+ z = z .value .reshape ((k , n ))
2541
+ u2 = u2 .value .reshape ((k , n ))
2542
+
2543
+ c = Zp .array (a * r + (1 - (z << 1 )) * u2 )
2544
+ del a , r , u2
2545
+ # -1 is nonsquare for Blum p, u2[i,j] !=0 w.v.h.p.
2546
+ # If a[j] == 0, c[i,j] is square mod p iff z[i,j] == 0.
2547
+ # If a[j] != 0, c[i,j] is square mod p independent of z[i,j].
2548
+ c = await self .output (c , threshold = 2 * self .threshold )
2549
+ z = np .where (c .value == 0 , 0 , z )
2550
+ c = np .where (c .is_sqr (), 1 - z , z )
2551
+ del z
2552
+ e = await self .np_all (map (Zp .array , c ))
2553
+ e <<= stype .frac_length
2554
+ e = e .reshape (shape )
2555
+ return e
2515
2556
2516
2557
@mpc_coro
2517
2558
async def np_sgn (self , a , l = None , LT = False , EQ = False ):
0 commit comments