Skip to content

Commit 9876610

Browse files
committed
Extend mpc.random.shuffle() to lists of lists.
Consistent with mpc.sorted(), mpc.if_else(), mpc.if_swap(), mpc.min(), mpc.argmax() etc., which also work for lists of (all same length) lists.
1 parent 1dc3fe1 commit 9876610

File tree

2 files changed

+28
-7
lines changed

2 files changed

+28
-7
lines changed

mpyc/random.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -192,17 +192,35 @@ def shuffle(sectype, x):
192192
"""Shuffle list x secretly in-place, and return None.
193193
194194
Given list x may contain public or secret elements.
195+
Elements of x are all numbers or all lists (of the same length) of numbers.
195196
"""
196197
n = len(x)
197-
if not isinstance(x[0], sectype): # assume same type for all elts of x
198-
for i in range(len(x)):
199-
x[i] = sectype(x[i])
198+
# assume same type for all elts of x
199+
x_i_is_list = isinstance(x[0], list)
200+
if not x_i_is_list:
201+
# elements of x are numbers
202+
if not isinstance(x[0], sectype):
203+
for i in range(n):
204+
x[i] = sectype(x[i])
205+
for i in range(n-1):
206+
u = random_unit_vector(sectype, n - i)
207+
x_u = runtime.in_prod(x[i:], u)
208+
d = runtime.scalar_mul(x[i] - x_u, u)
209+
x[i] = x_u
210+
x[i:] = runtime.vector_add(x[i:], d)
211+
return
212+
213+
# elements of x are lists of numbers
214+
for j in range(len(x[0])):
215+
if not isinstance(x[0][j], sectype):
216+
for i in range(n):
217+
x[i][j] = sectype(x[i][j])
200218
for i in range(n-1):
201219
u = random_unit_vector(sectype, n - i)
202-
x_u = runtime.in_prod(x[i:], u)
203-
d = runtime.scalar_mul(x[i] - x_u, u)
220+
x_u = runtime.matrix_prod([u], x[i:])[0]
221+
d = runtime.matrix_prod([[a] for a in u], [runtime.vector_sub(x[i], x_u)])
204222
x[i] = x_u
205-
x[i:] = runtime.vector_add(x[i:], d)
223+
x[i:] = runtime.matrix_add(x[i:], d)
206224

207225

208226
def random_permutation(sectype, x):

tests/test_random.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,12 @@ def test_secint(self):
3838

3939
x = list(range(8))
4040
shuffle(secint, x)
41-
shuffle(secint, x)
4241
x = mpc.run(mpc.output(x))
4342
self.assertSetEqual(set(x), set(range(8)))
43+
x = list(map(list, zip(range(8), range(0, -8, -1))))
44+
shuffle(secint, x)
45+
a = mpc.run(mpc.output(x[0]))
46+
self.assertEqual(a[1], -a[0])
4447
x = mpc.run(mpc.output(random_permutation(secint, 8)))
4548
self.assertSetEqual(set(x), set(range(8)))
4649
x = mpc.run(mpc.output(random_derangement(secint, 2)))

0 commit comments

Comments
 (0)