Skip to content

Commit 4b9a03e

Browse files
authored
Add np_aes, np_lpsolver(fxp) demos.
1 parent f6f1952 commit 4b9a03e

File tree

9 files changed

+796
-15
lines changed

9 files changed

+796
-15
lines changed

demos/np_aes.py

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""Demo Threshold AES cipher, vectorized.
2+
3+
This demo is a fully equivalent reimplementation of the aes.py demo.
4+
Secure arrays over GF(256) to perform all computations in a vectorized
5+
manner. For example, in each encryption round the S-Boxes are evaluated
6+
for all 16 bytes of the state in one go; in aes.py this was done by
7+
applying the S-Box to each byte one at a time. Similarly, the 4 S-Boxes
8+
in each round of the key expansion are evaluated in one go.
9+
10+
Apart from reducing the overhead, which makes the vectorized version about
11+
twice as fast as aes.py, the code is rather simple as well.
12+
13+
See demo aes.py for background information.
14+
"""
15+
16+
import sys
17+
import numpy as np
18+
from mpyc.runtime import mpc
19+
20+
secfld = mpc.SecFld(2**8) # Secure AES field GF(2^8) for secret values.
21+
f256 = secfld.field # Plain AES field GF(2^8) for public values.
22+
23+
24+
def circulant(r):
25+
"""Circulant matrix with first row r."""
26+
r = np.stack([np.roll(r, j, axis=0) for j in range(len(r))])
27+
return f256.array(r)
28+
29+
30+
A = circulant([1, 0, 0, 0, 1, 1, 1, 1]) # 8x8 matrix A over GF(2)
31+
A1 = np.linalg.inv(A) # inverse of A
32+
B = f256.array([1, 1, 0, 0, 0, 1, 1, 0]) # vector B over GF(2)
33+
C = circulant([2, 3, 1, 1]) # 4x4 matrix C over GF(2^8)
34+
C1 = np.linalg.inv(C) # inverse of C
35+
36+
37+
def sbox(x):
38+
"""AES S-Box."""
39+
x = mpc.np_to_bits(x**254)
40+
x = (A @ x[..., np.newaxis]).reshape(*x.shape)
41+
x += B
42+
x = mpc.np_from_bits(x)
43+
return x
44+
45+
46+
def sbox1(x):
47+
"""AES inverse S-Box."""
48+
x = mpc.np_to_bits(x)
49+
x += B
50+
x = (A1 @ x[..., np.newaxis]).reshape(*x.shape)
51+
x = mpc.np_from_bits(x)**254
52+
return x
53+
54+
55+
def key_expansion(k):
56+
"""AES key expansion for 128/256-bit keys."""
57+
w = k
58+
Nk = k.shape[1] # Nk is 4 or 8
59+
Nr = 10 if Nk == 4 else 14
60+
for i in range(Nk, 4*(Nr+1)):
61+
t = w[:, -1]
62+
if i % Nk in (0, 4):
63+
t = sbox(t)
64+
if i % Nk == 0:
65+
t = np.roll(t, -1, axis=0)
66+
t = mpc.np_update(t, 0, t[0] + (f256(1) << i // Nk - 1))
67+
t += w[:, -Nk]
68+
t = t.reshape(4, 1)
69+
w = np.append(w, t, axis=1)
70+
K = np.hsplit(w, Nr+1)
71+
return K
72+
73+
74+
def encrypt(K, s):
75+
"""AES encryption of s given key schedule K."""
76+
Nr = len(K) - 1 # Nr is 10 or 14
77+
s += K[0]
78+
for i in range(1, Nr+1):
79+
s = sbox(s)
80+
s = np.stack([np.roll(s[j], -j, axis=0) for j in range(4)])
81+
if i < Nr:
82+
s = C @ s
83+
s += K[i]
84+
return s
85+
86+
87+
def decrypt(K, s):
88+
"""AES decryption of s given key schedule K."""
89+
Nr = len(K) - 1 # Nr is 10 or 14
90+
for i in range(Nr, 0, -1):
91+
s += K[i]
92+
if i < Nr:
93+
s = C1 @ s
94+
s = np.stack([np.roll(s[j], j, axis=0) for j in range(4)])
95+
s = sbox1(s)
96+
s += K[0]
97+
return s
98+
99+
100+
async def xprint(text, s):
101+
"""Print matrix s transposed and flattened as hex string."""
102+
s = await mpc.output(s)
103+
s = s.T.flatten()
104+
print(f'{text} {bytes(map(int, s)).hex()}')
105+
106+
107+
async def main():
108+
if sys.argv[1:]:
109+
full = False
110+
print('AES-128 encryption only.')
111+
else:
112+
full = True
113+
print('AES-128 en/decryption and AES-256 en/decryption.')
114+
115+
print('AES polynomial:', f256.modulus) # x^8 + x^4 + x^3 + x + 1
116+
117+
await mpc.start()
118+
119+
p = secfld.array(f256.array([[17 * (4*j + i) for j in range(4)] for i in range(4)]))
120+
await xprint('Plaintext: ', p)
121+
122+
k128 = secfld.array(f256.array([[4*j + i for j in range(4)] for i in range(4)]))
123+
await xprint('AES-128 key:', k128)
124+
K = key_expansion(k128)
125+
c = encrypt(K, p)
126+
await xprint('Ciphertext: ', c)
127+
if full:
128+
p = decrypt(K, c)
129+
await xprint('Plaintext: ', p)
130+
131+
k256 = secfld.array(f256.array([[4*j + i for j in range(8)] for i in range(4)]))
132+
await xprint('AES-256 key:', k256)
133+
K = key_expansion(k256)
134+
c = encrypt(K, p)
135+
136+
await xprint('Ciphertext: ', c)
137+
p = decrypt(K, c)
138+
await xprint('Plaintext: ', p)
139+
140+
await mpc.shutdown()
141+
142+
if __name__ == '__main__':
143+
mpc.run(main())

demos/np_lpsolver.py

Lines changed: 240 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,240 @@
1+
"""Demo Linear Programming (LP) solver, using secure integer arithmetic.
2+
3+
Vectorized.
4+
5+
See demo lpsolver.py for background information.
6+
7+
... work in progress for MPyC version 0.9
8+
"""
9+
10+
import os
11+
import logging
12+
import argparse
13+
import csv
14+
import math
15+
import numpy as np
16+
from mpyc.runtime import mpc
17+
18+
19+
# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolverfxp.py)
20+
21+
22+
def argmin_int(x):
23+
secintarr = type(x)
24+
n = len(x)
25+
if n == 1:
26+
return (secintarr(np.array([1])), x[0])
27+
28+
if n == 2:
29+
b = x[0] < x[1]
30+
arg = mpc.np_fromlist([b, 1 - b])
31+
min = b * (x[0] - x[1]) + x[1]
32+
return arg, min
33+
34+
a = x[:n//2] ## split even odd? start at n%2 as in reduce in mpctools
35+
b = x[(n+1)//2:]
36+
c = a < b
37+
m = c * (a - b) + b
38+
if n%2 == 1:
39+
m = np.concatenate((m, x[n//2:(n+1)//2]))
40+
ag, mn = argmin_int(m)
41+
if n%2 == 1:
42+
ag_1 = ag[-1:]
43+
ag = ag[:-1]
44+
arg1 = ag * c
45+
arg2 = ag - arg1
46+
if n%2 == 1:
47+
arg = np.concatenate((arg1, ag_1, arg2))
48+
else:
49+
arg = np.concatenate((arg1, arg2))
50+
return arg, mn
51+
52+
53+
def argmin_rat(nd):
54+
secintarr = type(nd)
55+
N = nd.shape[1]
56+
if N == 1:
57+
return (secintarr(np.array([1])), (nd[0, 0], nd[1, 0]))
58+
59+
if N == 2:
60+
b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0
61+
arg = mpc.np_fromlist([b, 1 - b])
62+
min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1]
63+
return arg, (min[0], min[1])
64+
65+
a = nd[:, :N//2]
66+
b = nd[:, (N+1)//2:]
67+
c = a[0] * b[1] < b[0] * a[1]
68+
m = c * (a - b) + b
69+
if N%2 == 1:
70+
m = np.concatenate((m, nd[:, N//2:(N+1)//2]), axis=1)
71+
ag, mn = argmin_rat(m)
72+
if N%2 == 1:
73+
ag_1 = ag[-1:]
74+
ag = ag[:-1]
75+
arg1 = ag * c
76+
arg2 = ag - arg1
77+
if N%2 == 1:
78+
arg = np.concatenate((arg1, ag_1, arg2))
79+
else:
80+
arg = np.concatenate((arg1, arg2))
81+
return arg, mn
82+
83+
84+
def pow_list(a, x, n):
85+
"""Return [a,ax, ax^2, ..., ax^(n-1)].
86+
87+
Runs in O(log n) rounds using minimal number of n-1 secure multiplications.
88+
"""
89+
if n == 1:
90+
powers = mpc.np_fromlist([a])
91+
elif n == 2:
92+
powers = mpc.np_fromlist([a, a * x])
93+
else:
94+
even_powers = pow_list(a, x * x, (n+1)//2)
95+
if n%2:
96+
d = even_powers[-1]
97+
even_powers = even_powers[:-1]
98+
odd_powers = x * even_powers
99+
powers = np.stack((even_powers, odd_powers))
100+
powers = powers.T.reshape(n - (n%2))
101+
if n%2:
102+
powers = np.append(powers, mpc.np_fromlist([d]))
103+
return powers
104+
105+
106+
# TODO: consider next version of unit_vector() as alternative for current mpc.unit_vector()
107+
@mpc.coroutine
108+
async def unit_vector(a, n):
109+
"""Length-n unit vector [0]*a + [1] + [0]*(n-1-a) for secret a, assuming 0 <= a < n.
110+
111+
NB: If a = n, unit vector [1] + [0]*(n-1) is returned. See mpyc.statistics.
112+
"""
113+
await mpc.returnType(type(a), n)
114+
# TODO: add conversion from prime field GF(p) to secint, 0<=a<n, n < p/2 needed?
115+
u = await mpc.gather(mpc.random.random_unit_vector(type(a), n))
116+
r = sum(a * b for a, b in zip(u, range(n)))
117+
R = 1 + mpc._random(type(a).field, 1<<mpc.options.sec_param)
118+
c = await mpc.output(a - r + R * n)
119+
c %= n
120+
# rotate u over c positions to the right
121+
v = u[:n-c]
122+
u = u[n-c:]
123+
u.extend(v)
124+
return u
125+
126+
127+
async def main():
128+
parser = argparse.ArgumentParser()
129+
parser.add_argument('-i', '--dataset', type=int, metavar='I',
130+
help=('dataset 0=uvlp (default), 1=wiki, 2=tb2x2, 3=woody, '
131+
'4=LPExample_R20, 5=sc50b, 6=kb2, 7=LPExample'))
132+
parser.add_argument('-l', '--bit-length', type=int, metavar='L',
133+
help='override preset bit length for dataset')
134+
parser.set_defaults(dataset=0, bit_length=0)
135+
args = parser.parse_args()
136+
137+
settings = [('uvlp', 8, 1, 2),
138+
('wiki', 6, 1, 2),
139+
('tb2x2', 6, 1, 2),
140+
('woody', 8, 1, 3),
141+
('LPExample_R20', 70, 1, 5),
142+
('sc50b', 104, 10, 55),
143+
('kb2', 536, 100000, 106),
144+
('LPExample', 110, 1, 178)]
145+
name, bit_length, scale, n_iter = settings[args.dataset]
146+
if args.bit_length:
147+
bit_length = args.bit_length
148+
149+
T = np.genfromtxt(os.path.join('data', 'lp', name + '.csv'), dtype=float, delimiter=',')
150+
m, n = T.shape[0] - 1, T.shape[1] - 1
151+
secint = mpc.SecInt(bit_length, n=m + n) # force existence of Nth root of unity, N>=m+n
152+
print(f'Using secure {bit_length}-bit integers: {secint.__name__}')
153+
print(f'dataset: {name} with {m} constraints and {n} variables (scale factor {scale})')
154+
T[0, -1] = 0.0 # initialize optimal value
155+
T = np.vectorize(int, otypes='O')(T * scale)
156+
g = np.gcd.reduce(T[1:], axis=1, keepdims=True)
157+
T[1:] //= np.maximum(g, 1) # remove common factors per row (skipping cost row)
158+
T = secint.array(T)
159+
c, A, b = -T[0, :-1], T[1:, :-1], T[1:, -1] # maximize c.x subject to A.x <= b, x >= 0
160+
161+
Zp = secint.field
162+
N = Zp.nth
163+
w = Zp.root # w is an Nth root of unity in Zp, where N >= m + n
164+
w_powers = [Zp(1)]
165+
for _ in range(N-1):
166+
w_powers.append(w_powers[-1] * w)
167+
assert w_powers[-1] * w == 1
168+
169+
await mpc.start()
170+
171+
cobasis = Zp.array(np.array([w_powers[-j].value for j in range(n)]))
172+
basis = Zp.array(np.array([w_powers[-(i + n)].value for i in range(m)]))
173+
previous_pivot = secint(1)
174+
175+
iteration = 0
176+
while True:
177+
# find index of pivot column
178+
p_col_index, minimum = argmin_int(T[0, :-1])
179+
if await mpc.output(minimum >= 0):
180+
break # maximum reached
181+
182+
# find index of pivot row
183+
p_col = T[:, :-1] @ p_col_index
184+
den = p_col[1:]
185+
num = T[1:, -1] + (den <= 0)
186+
p_row_index, (_, pivot) = argmin_rat(np.stack((num, den)))
187+
188+
# reveal progress a bit
189+
iteration += 1
190+
mx, cd, p = await mpc.output([T[0, -1], previous_pivot, pivot])
191+
logging.info(f'Iteration {iteration}/{n_iter}: {mx / cd} pivot={p / cd}')
192+
193+
# swap basis entries
194+
delta = basis @ p_row_index - cobasis @ p_col_index
195+
cobasis += delta * p_col_index
196+
basis -= delta * p_row_index
197+
198+
# update tableau Tij = Tij*Tkl/Tkl' - (Til/Tkl' - bool(i==k)) * (Tkj + bool(j==l)*Tkl')
199+
p_col_index = np.concatenate((p_col_index, np.array([0])))
200+
p_row_index = np.concatenate((np.array([0]), p_row_index))
201+
202+
pp_inv = 1 / previous_pivot
203+
p_col = p_col * pp_inv - p_row_index
204+
p_row = p_row_index @ T + previous_pivot * p_col_index
205+
T = T * (pivot * pp_inv) - p_col.reshape(len(p_col), 1) @ p_row.reshape(1, len(p_row)) # consider np.gauss
206+
previous_pivot = pivot
207+
208+
mx = await mpc.output(T[0, -1])
209+
cd = await mpc.output(previous_pivot) # common denominator for all entries of T
210+
print(f'max = {mx} / {cd} / {scale} = {mx / cd / scale} in {iteration} iterations')
211+
212+
logging.info('Solution x')
213+
sum_powers = secint.array(np.zeros((N,), dtype='O')) # TODO: compare with fromiter approach below
214+
for i in range(m):
215+
x_powers = pow_list(T[i+1][-1] / N, basis[i], N)
216+
sum_powers += x_powers
217+
coefs = Zp.array([[w_powers[(j * k) % N].value for k in range(N)] for j in range(n)]) ## TODO: vandermonde ff array
218+
x = coefs @ sum_powers
219+
Ax_bounded_by_b = mpc.all((A @ x <= b * cd).tolist()) ## TODO: np.all
220+
x_nonnegative = mpc.all((x >= 0).tolist())
221+
222+
logging.info('Dual solution y')
223+
coefs = Zp.array([[w_powers[((n + i) * k) % N].value for k in range(N)] for i in range(m)])
224+
sum_powers = np.sum(np.fromiter((pow_list(T[0][j] / N, cobasis[j], N) for j in range(n)), 'O')) #TODO fix for Numpy 1.22, WSL
225+
y = coefs @ sum_powers
226+
yA_bounded_by_c = mpc.all((y @ A >= c * cd).tolist())
227+
y_nonnegative = mpc.all((y >= 0).tolist())
228+
229+
cx_eq_yb = c @ x == y @ b
230+
check = mpc.all([cx_eq_yb, Ax_bounded_by_b, x_nonnegative, yA_bounded_by_c, y_nonnegative])
231+
check = bool(await mpc.output(check))
232+
print(f'verification c.x == y.b, A.x <= b, x >= 0, y.A >= c, y >= 0: {check}')
233+
234+
x = await mpc.output(x)
235+
print(f'solution = {[a / cd for a in x]}')
236+
237+
await mpc.shutdown()
238+
239+
if __name__ == '__main__':
240+
mpc.run(main())

0 commit comments

Comments
 (0)