Skip to content

Commit 073ba09

Browse files
committed
Finalize mpc.np_equal().
1 parent 4b9a03e commit 073ba09

File tree

2 files changed

+59
-8
lines changed

2 files changed

+59
-8
lines changed

mpyc/runtime.py

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,7 @@ def pow(self, a, b):
10701070

10711071
def np_pow(self, a, b):
10721072
"""Secure exponentiation a raised to the power of b, for public integer b."""
1073+
# TODO: extend to non-scalar b
10731074
if b == 254: # addition chain for AES S-Box (11 multiplications in 9 rounds)
10741075
d = a
10751076
c = self.np_multiply(d, d)
@@ -1142,7 +1143,7 @@ def is_zero(self, a):
11421143
return self.sgn(a, EQ=True)
11431144

11441145
@mpc_coro
1145-
async def _is_zero(self, a):
1146+
async def _is_zero(self, a): # a la [NO07]
11461147
"""Probabilistic zero test."""
11471148
stype = type(a)
11481149
await self.returnType((stype, True))
@@ -1156,9 +1157,9 @@ async def _is_zero(self, a):
11561157
a = a.value
11571158
r = self._randoms(Zp, k)
11581159
c = [Zp(a * r[i].value + (1-(z[i].value << 1)) * u2[i].value) for i in range(k)]
1159-
# -1 is nonsquare for Blum p, u_i !=0 w.v.h.p.
1160-
# If a == 0, c_i is square mod p iff z[i] == 0.
1161-
# If a != 0, c_i is square mod p independent of z[i].
1160+
# -1 is nonsquare for Blum p, u[i] !=0 w.v.h.p.
1161+
# If a == 0, c[i] is square mod p iff z[i] == 0.
1162+
# If a != 0, c[i] is square mod p independent of z[i].
11621163
c = await self.output(c, threshold=2*self.threshold)
11631164
for i in range(k):
11641165
if c[i] == 0:
@@ -2508,10 +2509,50 @@ def np_less(self, a, b):
25082509
return self.np_sgn(a - b, LT=True)
25092510

25102511
def np_equal(self, a, b):
2511-
"""Secure comparison a < b."""
2512-
return self.np_sgn(a - b, EQ=True)
2513-
# TODO: use prob. zerotest as well like in is_zero
2514-
# TODO: cover finite field arrays (introduce np_pow())
2512+
"""Secure comparison a == b."""
2513+
d = a - b
2514+
stype = d.sectype
2515+
if issubclass(stype, self.SecureFiniteField):
2516+
return 1 - self.np_pow(d, stype.field.order - 1)
2517+
2518+
if stype.bit_length/2 > self.options.sec_param >= 8 and stype.field.order%4 == 3:
2519+
return self._np_is_zero(d)
2520+
2521+
return self.np_sgn(d, EQ=True)
2522+
2523+
@mpc_coro
2524+
async def _np_is_zero(self, a):
2525+
"""Probabilistic zero test, elementwise."""
2526+
stype = type(a)
2527+
shape = a.shape
2528+
await self.returnType((stype, True, shape))
2529+
Zp = stype.sectype.field
2530+
2531+
n = a.size
2532+
k = self.options.sec_param
2533+
z = self.np_random_bits(Zp, k * n)
2534+
r = self._np_randoms(Zp, k * n)
2535+
u2 = self._reshare(r * r)
2536+
r = self._np_randoms(Zp, k * n)
2537+
a, u2, z = await self.gather(a, u2, z)
2538+
a = a.value.reshape((n,))
2539+
r = r.value.reshape((k, n))
2540+
z = z.value.reshape((k, n))
2541+
u2 = u2.value.reshape((k, n))
2542+
2543+
c = Zp.array(a * r + (1-(z << 1)) * u2)
2544+
del a, r, u2
2545+
# -1 is nonsquare for Blum p, u2[i,j] !=0 w.v.h.p.
2546+
# If a[j] == 0, c[i,j] is square mod p iff z[i,j] == 0.
2547+
# If a[j] != 0, c[i,j] is square mod p independent of z[i,j].
2548+
c = await self.output(c, threshold=2*self.threshold)
2549+
z = np.where(c.value == 0, 0, z)
2550+
c = np.where(c.is_sqr(), 1 - z, z)
2551+
del z
2552+
e = await self.np_all(map(Zp.array, c))
2553+
e <<= stype.frac_length
2554+
e = e.reshape(shape)
2555+
return e
25152556

25162557
@mpc_coro
25172558
async def np_sgn(self, a, l=None, LT=False, EQ=False):

tests/test_runtime.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,13 @@ def test_secfxp_array(self):
138138
np.assertEqual(b, a * a)
139139
self.assertTrue(np.issubdtype(b.dtype, np.floating))
140140

141+
np.assertEqual(mpc.run(mpc.output(np.equal(c, c))), True)
142+
np.assertEqual(mpc.run(mpc.output(np.equal(c, 0))), False)
143+
secnum = mpc.SecFxp(64)
144+
c = secnum.array(a)
145+
np.assertEqual(mpc.run(mpc.output(np.equal(c, c))), True)
146+
np.assertEqual(mpc.run(mpc.output(np.equal(c, 0))), False)
147+
141148
@unittest.skipIf(not np, 'NumPy not available or inside MPyC disabled')
142149
def test_secfld_array(self):
143150
np.assertEqual = np.testing.assert_array_equal
@@ -150,6 +157,9 @@ def test_secfld_array(self):
150157
self.assertEqual(len(c), 1)
151158
self.assertEqual(len(c.T), 2)
152159

160+
np.assertEqual(mpc.run(mpc.output(np.equal(c, c))), True)
161+
np.assertEqual(mpc.run(mpc.output(np.equal(c, c+1))), False)
162+
153163
def test_async(self):
154164
mpc.options.no_async = False
155165
a = mpc.SecInt()(7)

0 commit comments

Comments
 (0)