Skip to content

Commit 1dc3fe1

Browse files
authored
First version of vectorized mpc.np_argmin/max().
1 parent f3fe903 commit 1dc3fe1

12 files changed

+215
-231
lines changed

.travis.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ python:
44
- 3.8
55
- 3.9
66
- 3.10
7+
- 3.11
78
- pypy3.8-7.3.9
89
install:
910
- pip install --upgrade pip

demos/lpsolver.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -153,13 +153,13 @@ async def main():
153153
args = parser.parse_args()
154154

155155
settings = [('uvlp', 8, 1, 2),
156-
('wiki', 6, 1, 2),
156+
('wiki', 6, 1, 1),
157157
('tb2x2', 6, 1, 2),
158158
('woody', 8, 1, 3),
159-
('LPExample_R20', 70, 1, 5),
159+
('LPExample_R20', 70, 1, 9),
160160
('sc50b', 104, 10, 55),
161-
('kb2', 536, 100000, 106),
162-
('LPExample', 110, 1, 178)]
161+
('kb2', 560, 100000, 154),
162+
('LPExample', 110, 1, 175)]
163163
name, bit_length, scale, n_iter = settings[args.dataset]
164164
if args.bit_length:
165165
bit_length = args.bit_length
@@ -200,11 +200,9 @@ async def main():
200200
previous_pivot = secint(1)
201201

202202
iteration = 0
203-
while True:
203+
while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
204204
# find index of pivot column
205-
p_col_index, minimum = argmin_int(T[0][:-1])
206-
if await mpc.output(minimum >= 0):
207-
break # maximum reached
205+
p_col_index = arg_min[0]
208206

209207
# find index of pivot row
210208
p_col = mpc.matrix_prod([p_col_index], T, True)[0]

demos/lpsolverfxp.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,9 @@ async def main():
8787
basis = [secfxp(n + i) for i in range(m)]
8888

8989
iteration = 0
90-
while True:
90+
while await mpc.output((arg_min := argmin_int(T[0][:-1]))[1] < 0):
9191
# find index of pivot column
92-
p_col_index, minimum = argmin_int(T[0][:-1])
93-
94-
if await mpc.output(minimum >= 0):
95-
break # maximum reached
92+
p_col_index = arg_min[0]
9693

9794
# find index of pivot row
9895
p_col = mpc.matrix_prod([p_col_index], T, True)[0]

demos/np_bnnmnist.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -240,8 +240,7 @@ async def main():
240240
if args.no_legendre:
241241
secint.bit_length = 14
242242
for i in range(batch_size):
243-
prediction = int(await mpc.output(mpc.argmax(L[i].tolist())[0]))
244-
243+
prediction = await mpc.output(np.argmax(L[i]))
245244
error = '******* ERROR *******' if prediction != labels[i] else ''
246245
print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
247246
print(await mpc.output(L[i]))

demos/np_cnnmnist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ async def main():
177177
secnum.bit_length = 37
178178

179179
for i in range(batch_size):
180-
prediction = int(await mpc.output(mpc.argmax(x[i].tolist())[0]))
180+
prediction = int(await mpc.output(np.argmax(x[i])))
181181
error = '******* ERROR *******' if prediction != labels[i] else ''
182182
print(f'Image #{offset+i} with label {labels[i]}: {prediction} predicted. {error}')
183183
print(await mpc.output(x[i]))

demos/np_id3gini.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
77
See id3gini.py for background information on decision tree learning and ID3.
88
"""
9-
# TODO: vectorize mpc.argmax()
109

1110
import os
1211
import logging
@@ -20,7 +19,7 @@
2019
@mpc.coroutine
2120
async def id3(T, R) -> asyncio.Future:
2221
sizes = S[C] @ T
23-
i, mx = mpc.argmax(sizes)
22+
i, mx = sizes.argmax(raw=False)
2423
sizeT = sizes.sum()
2524
stop = (sizeT <= int(args.epsilon * len(T))) + (mx == sizeT)
2625
if not (R and await mpc.is_zero_public(stop)):
@@ -29,7 +28,8 @@ async def id3(T, R) -> asyncio.Future:
2928
tree = i
3029
else:
3130
T_SC = (T * S[C]).T
32-
k = mpc.argmax([GI(S[A] @ T_SC) for A in R], key=SecureFraction)[0]
31+
CT = np.stack(tuple(GI(S[A] @ T_SC) for A in R))
32+
k = CT.argmax(key=SecureFraction, raw=False, raw2=False)
3333
A = list(R)[await mpc.output(k)]
3434
logging.info(f'Attribute node {A}')
3535
T_SA = T * S[A]
@@ -46,15 +46,15 @@ def GI(x):
4646
y = args.alpha * np.sum(x, axis=1) + 1 # NB: alternatively, use s + (s == 0)
4747
D = mpc.prod(y.tolist())
4848
G = np.sum(np.sum(x * x, axis=1) / y)
49-
return [D * G, D] # numerator, denominator
49+
return mpc.np_fromlist([D * G, D]) # numerator, denominator
5050

5151

5252
class SecureFraction:
5353
def __init__(self, a):
54-
self.n, self.d = a # numerator, denominator
54+
self.a = a # numerator, denominator
5555

5656
def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort()
57-
return mpc.in_prod([self.n, -self.d], [other.d, other.n]) < 0
57+
return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0]
5858

5959

6060
depth = lambda tree: 0 if isinstance(tree, int) else max(map(depth, tree[1])) + 1

demos/np_lpsolver.py

Lines changed: 14 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -16,69 +16,12 @@
1616
from mpyc.runtime import mpc
1717

1818

19-
# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolverfxp.py)
19+
class SecureFraction:
20+
def __init__(self, a):
21+
self.a = a # numerator, denominator
2022

21-
22-
def argmin_int(x):
23-
secintarr = type(x)
24-
n = len(x)
25-
if n == 1:
26-
return (secintarr(np.array([1])), x[0])
27-
28-
if n == 2:
29-
b = x[0] < x[1]
30-
arg = mpc.np_fromlist([b, 1 - b])
31-
min = b * (x[0] - x[1]) + x[1]
32-
return arg, min
33-
34-
a = x[:n//2] ## split even odd? start at n%2 as in reduce in mpctools
35-
b = x[(n+1)//2:]
36-
c = a < b
37-
m = c * (a - b) + b
38-
if n%2 == 1:
39-
m = np.concatenate((m, x[n//2:(n+1)//2]))
40-
ag, mn = argmin_int(m)
41-
if n%2 == 1:
42-
ag_1 = ag[-1:]
43-
ag = ag[:-1]
44-
arg1 = ag * c
45-
arg2 = ag - arg1
46-
if n%2 == 1:
47-
arg = np.concatenate((arg1, ag_1, arg2))
48-
else:
49-
arg = np.concatenate((arg1, arg2))
50-
return arg, mn
51-
52-
53-
def argmin_rat(nd):
54-
secintarr = type(nd)
55-
N = nd.shape[1]
56-
if N == 1:
57-
return (secintarr(np.array([1])), (nd[0, 0], nd[1, 0]))
58-
59-
if N == 2:
60-
b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0
61-
arg = mpc.np_fromlist([b, 1 - b])
62-
min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1]
63-
return arg, (min[0], min[1])
64-
65-
a = nd[:, :N//2]
66-
b = nd[:, (N+1)//2:]
67-
c = a[0] * b[1] < b[0] * a[1]
68-
m = c * (a - b) + b
69-
if N%2 == 1:
70-
m = np.concatenate((m, nd[:, N//2:(N+1)//2]), axis=1)
71-
ag, mn = argmin_rat(m)
72-
if N%2 == 1:
73-
ag_1 = ag[-1:]
74-
ag = ag[:-1]
75-
arg1 = ag * c
76-
arg2 = ag - arg1
77-
if N%2 == 1:
78-
arg = np.concatenate((arg1, ag_1, arg2))
79-
else:
80-
arg = np.concatenate((arg1, arg2))
81-
return arg, mn
23+
def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort()
24+
return self.a[:, 0] * other.a[:, 1] < self.a[:, 1] * other.a[:, 0]
8225

8326

8427
def pow_list(a, x, n):
@@ -135,13 +78,13 @@ async def main():
13578
args = parser.parse_args()
13679

13780
settings = [('uvlp', 8, 1, 2),
138-
('wiki', 6, 1, 2),
81+
('wiki', 6, 1, 1),
13982
('tb2x2', 6, 1, 2),
14083
('woody', 8, 1, 3),
141-
('LPExample_R20', 70, 1, 5),
84+
('LPExample_R20', 70, 1, 9),
14285
('sc50b', 104, 10, 55),
143-
('kb2', 536, 100000, 106),
144-
('LPExample', 110, 1, 178)]
86+
('kb2', 560, 100000, 154),
87+
('LPExample', 110, 1, 175)]
14588
name, bit_length, scale, n_iter = settings[args.dataset]
14689
if args.bit_length:
14790
bit_length = args.bit_length
@@ -173,17 +116,15 @@ async def main():
173116
previous_pivot = secint(1)
174117

175118
iteration = 0
176-
while True:
119+
while await mpc.output((arg_min := T[0, :-1].argmin())[1] < 0):
177120
# find index of pivot column
178-
p_col_index, minimum = argmin_int(T[0, :-1])
179-
if await mpc.output(minimum >= 0):
180-
break # maximum reached
121+
p_col_index = arg_min[0]
181122

182123
# find index of pivot row
183124
p_col = T[:, :-1] @ p_col_index
184-
den = p_col[1:]
185-
num = T[1:, -1] + (den <= 0)
186-
p_row_index, (_, pivot) = argmin_rat(np.stack((num, den)))
125+
denominator = p_col[1:]
126+
constraints = np.column_stack((T[1:, -1] + (denominator <= 0), denominator))
127+
p_row_index, (_, pivot) = constraints.argmin(key=SecureFraction)
187128

188129
# reveal progress a bit
189130
iteration += 1
@@ -198,7 +139,6 @@ async def main():
198139
# update tableau Tij = Tij*Tkl/Tkl' - (Til/Tkl' - bool(i==k)) * (Tkj + bool(j==l)*Tkl')
199140
p_col_index = np.concatenate((p_col_index, np.array([0])))
200141
p_row_index = np.concatenate((np.array([0]), p_row_index))
201-
202142
pp_inv = 1 / previous_pivot
203143
p_col = p_col * pp_inv - p_row_index
204144
p_row = p_row_index @ T + previous_pivot * p_col_index

0 commit comments

Comments
 (0)