Skip to content

Commit f3fe903

Browse files
committed
Add np.outer for secure arrays.
1 parent 4dbd3d5 commit f3fe903

File tree

4 files changed

+43
-3
lines changed

4 files changed

+43
-3
lines changed

demos/np_lpsolver.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ async def main():
202202
pp_inv = 1 / previous_pivot
203203
p_col = p_col * pp_inv - p_row_index
204204
p_row = p_row_index @ T + previous_pivot * p_col_index
205-
T = T * (pivot * pp_inv) - p_col.reshape(len(p_col), 1) @ p_row.reshape(1, len(p_row)) # consider np.gauss
205+
T = T * (pivot * pp_inv) - np.outer(p_col, p_row)
206206
previous_pivot = pivot
207207

208208
mx = await mpc.output(T[0, -1])

demos/np_lpsolverfxp.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,14 @@ async def main():
178178
cobasis += delta * p_col_index
179179
basis -= delta * p_row_index
180180

181-
# update tableau Tij = Tij - (Til - bool(i==k))/Tkl * (Tkj + bool(j==l))
181+
# update tableau Tij = Tij - (Til - bool(i==k))/Tkl *outer (Tkj + bool(j==l))
182182
p_col_index = np.concatenate((p_col_index, np.array([0])))
183183
p_row_index = np.concatenate((np.array([0]), p_row_index))
184184
p_col_index.integral = True
185185
p_row_index.integral = True
186186
p_col = (p_col - p_row_index) / pivot
187187
p_row = p_row_index @ T + p_col_index
188-
T -= p_col.reshape(len(p_col), 1) @ p_row.reshape(1, len(p_row))
188+
T -= np.outer(p_col, p_row)
189189

190190
mx = await mpc.output(T[0, -1])
191191
rel_error = (mx - exact_max) / exact_max

mpyc/runtime.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,6 +2168,43 @@ async def np_matmul(self, A, B):
21682168
C = self.np_trunc(stype(C, shape=shape))
21692169
return C
21702170

2171+
@mpc_coro
2172+
async def np_outer(self, a, b):
2173+
"""Outer product of vectors a and b.
2174+
2175+
Input arrays a and b are flattened if not already 1d.
2176+
"""
2177+
sha = isinstance(a, self.SecureObject)
2178+
shb = isinstance(b, self.SecureObject)
2179+
stype = type(a) if sha else type(b)
2180+
shape = (a.size, b.size)
2181+
f = stype.frac_length
2182+
if not f:
2183+
rettype = (stype, shape)
2184+
else:
2185+
a_integral = a.integral
2186+
b_integral = b.integral
2187+
rettype = (stype, a_integral and b_integral, shape)
2188+
# TODO: handle a or b public integral value
2189+
await self.returnType(rettype)
2190+
2191+
if a is b:
2192+
a = b = await self.gather(a)
2193+
elif sha and shb:
2194+
a, b = await self.gather(a, b)
2195+
elif sha:
2196+
a = await self.gather(a)
2197+
else:
2198+
b = await self.gather(b)
2199+
c = np.outer(a, b) # NB: flattens a and/or b
2200+
if f and (a_integral or b_integral):
2201+
c >>= f # NB: in-place rshift
2202+
if sha and shb:
2203+
c = self._reshare(c)
2204+
if f and not a_integral and not b_integral:
2205+
c = self.np_trunc(stype(c, shape=shape))
2206+
return c
2207+
21712208
@mpc_coro_no_pc
21722209
async def np_getitem(self, a, key):
21732210
"""SecureArray a, index/slice key."""

tests/test_runtime.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def test_secint_array(self):
4646
d = np.stack((c, c, c))
4747
np.assertEqual(mpc.run(mpc.output(c @ d)), a @ b)
4848
np.assertEqual(mpc.run(mpc.output(d @ d)), b @ b)
49+
np.assertEqual(mpc.run(mpc.output(np.outer(c, c))), np.outer(a, a))
4950
np.assertEqual(mpc.run(mpc.output(np.stack((c, c), axis=1))), np.stack((a, a), axis=1))
5051
np.assertEqual(mpc.run(mpc.output(np.block([[c, c], [c, c]]))), np.block([[a, a], [a, a]]))
5152
np.assertEqual(mpc.run(mpc.output(np.block([[secint(9), -1]]))), np.block([[9, -1]]))
@@ -141,6 +142,7 @@ def test_secfxp_array(self):
141142
np.assertEqual(mpc.run(mpc.output(mpc.run(mpc.transfer(c, senders=0)))), a)
142143
np.assertEqual(mpc.run(mpc.output(mpc.input(c, senders=0))), a)
143144
np.assertEqual(mpc.run(mpc.output(mpc._reshare(c) @ c)), a @ a)
145+
np.assertEqual(mpc.run(mpc.output(np.outer(c, c))), np.outer(a, a))
144146
b = mpc.run(mpc.output(c * c))
145147
np.assertEqual(b, a * a)
146148
self.assertTrue(np.issubdtype(b.dtype, np.floating))
@@ -162,6 +164,7 @@ def test_secfld_array(self):
162164
c = secfld.array(a)
163165
np.assertEqual(mpc.run(mpc.np_is_zero_public(c.flatten())), [False, True, True, False])
164166
self.assertEqual(mpc.run(mpc.output(c.flatten().tolist())), [-1, 0, 0, -1])
167+
np.assertEqual(mpc.run(mpc.output(np.outer(c, c))), np.outer(a, a))
165168
self.assertEqual(len(c), 1)
166169
self.assertEqual(len(c.T), 2)
167170
self.assertTrue(bool(c))

0 commit comments

Comments
 (0)