Skip to content

Commit 2629c01

Browse files
authored
Add np_onewayhashchains demo; secfxp (np)mul optimization.
1 parent 701b73b commit 2629c01

File tree

7 files changed

+232
-26
lines changed

7 files changed

+232
-26
lines changed

demos/np_onewayhashchains.py

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
"""Demo Threshold One-Way Hash Chains, vectorized.
2+
3+
This demo is an extended reimplementation of the onewayhashchain.py demo for
4+
generating and reversing one-way hash chains in a multiparty setting.
5+
6+
In addition to the Matyas-Meyer-Oseas one-way function based on AES, the SHAKE128
7+
oneway function from the SHA3 faimlty is also provided as an option.
8+
9+
Note that in the output stage the hashes pertaining to different pebbles are
10+
evaluated in parallel, without increasing the overall round complexity. Multiple
11+
hashes pertaining to the same pebble, however, are necessarily evaluated in series,
12+
increasing the overall round complxity accordingly.
13+
14+
See demo onewayhashchain.py for more information.
15+
"""
16+
17+
import argparse
18+
import itertools
19+
import numpy as np
20+
from mpyc.runtime import mpc
21+
import np_aes as aes # vectorized AES demo operating on secure arrays over GF(256)
22+
import sha3 # vectorized SHA3/SHAKE demo operating on secure arrays over GF(2)
23+
24+
f = None # one-way function
25+
26+
27+
def tS(k, r):
28+
"""Optimal schedule for binary pebbling."""
29+
if r < 2**(k-1):
30+
return 0
31+
32+
return ((k + r)%2 + k+1 - ((2*r) % (2**(2**k - r).bit_length())).bit_length()) // 2
33+
34+
35+
def P(k, x):
36+
"""Recursive optimal binary pebbler outputs {f^i(x)}_{i=0}^{n-1} in reverse, n=2^k."""
37+
# initial stage
38+
y = [None]*k + [x]
39+
i = k
40+
g = 0
41+
for r in range(1, 2**k):
42+
for _ in range(tS(k, r)):
43+
z = y[i]
44+
if g == 0:
45+
i -= 1
46+
g = 2**i
47+
y[i] = f(z)
48+
g -= 1
49+
yield
50+
# output stage
51+
yield y[0]
52+
for v in itertools.zip_longest(*(P(i-1, y[i]) for i in range(1, k+1))):
53+
yield next(filter(None, v))
54+
55+
56+
def p(k, x):
57+
"""Iterative optimal binary pebbler generating {f^i(x)}_{i=0}^{n-1} in reverse, n=2^k."""
58+
# initial stage
59+
z = []
60+
y = x
61+
for h in range(2**k, 1, -1):
62+
if h & (h-1) == 0: # h is power of 2
63+
z.insert(0, y)
64+
y = f(y)
65+
yield
66+
# output stage
67+
yield y
68+
a = [None] * (k>>1)
69+
v = 0
70+
for r in range(2**k - 1, 0, -1):
71+
yield z[0]
72+
c = r
73+
i = 0
74+
while ~c&1:
75+
z[i] = z[i+1]
76+
i += 1
77+
c >>= 1
78+
i += 1
79+
c >>= 1
80+
if c&1:
81+
a[v] = (i, 0)
82+
v += 1
83+
u = v
84+
w = (r&1) + i+1
85+
while c:
86+
while ~c&1:
87+
w += 1
88+
c >>= 1
89+
u -= 1
90+
q, g = a[u]
91+
for _ in range(w//2):
92+
y = z[q]
93+
if not g:
94+
q -= 1
95+
g = 2**q
96+
z[q] = f(y)
97+
g -= 1
98+
if q:
99+
a[u] = q, g
100+
else:
101+
v -= 1
102+
w = w&1
103+
while c&1:
104+
w += 1
105+
c >>= 1
106+
107+
108+
async def main():
109+
parser = argparse.ArgumentParser()
110+
parser.add_argument('-k', '--order', type=int, metavar='K',
111+
help='order K of hash chain, length n=2**K')
112+
parser.add_argument('--recursive', action='store_true',
113+
help='use recursive pebbler')
114+
parser.add_argument('--sha3', action='store_true',
115+
help='use SHAKE128 as one-way function')
116+
parser.add_argument('--no-one-way', action='store_true',
117+
help='use dummy one-way function')
118+
parser.add_argument('--no-random-seed', action='store_true',
119+
help='use fixed seed')
120+
parser.set_defaults(order=1)
121+
args = parser.parse_args()
122+
123+
await mpc.start()
124+
125+
if args.recursive:
126+
Pebbler = P
127+
else:
128+
Pebbler = p
129+
130+
secfld = sha3.secfld if args.sha3 else aes.secfld
131+
132+
IV = np.array([[3] * 4] * 4) # IV as 4x4 array of bytes
133+
global f
134+
if args.no_one_way:
135+
D = aes.circulant([3, 0, 0, 0])
136+
f = lambda x: D @ x
137+
elif args.sha3:
138+
f = lambda x: sha3.shake(x, 128)
139+
else:
140+
K = aes.key_expansion(secfld.array(IV))
141+
f = lambda x: aes.encrypt(K, x) + x
142+
143+
if args.no_random_seed:
144+
if args.sha3:
145+
# convert 4x4 array of bytes to length-128 array of bits
146+
IV = np.array([(b >> i) & 1 for b in IV.flat for i in range(8)])
147+
x0 = secfld.array(IV)
148+
else:
149+
x0 = mpc.np_random_bits(secfld, 128)
150+
if not args.sha3:
151+
# convert length-128 array of bits to 4x4 array of bytes
152+
x0 = mpc.np_from_bits(x0.reshape(4, 4, 8))
153+
154+
xprint = sha3.xprint if args.sha3 else aes.xprint
155+
156+
k = args.order
157+
print(f'Hash chain of length {2**k}:')
158+
r = 1
159+
for v in Pebbler(k, x0):
160+
if v is None: # initial stage
161+
print(f'{r:4}', '-')
162+
await mpc.throttler(0.0625) # raise barrier every 16 calls to one-way f()
163+
else: # output stage
164+
await xprint(f'{r:4} x{2**(k+1) - 1 - r:<4} =', v)
165+
r += 1
166+
print(f'Performed {k * 2**(k-1) = } hashes in total.')
167+
168+
await mpc.shutdown()
169+
170+
if __name__ == '__main__':
171+
mpc.run(main())

demos/onewayhashchains.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ async def main():
143143
else: # output stage
144144
await aes.xprint(f'{r:4} x{2**(k+1) - 1 - r:<4} =', v)
145145
r += 1
146+
print(f'Performed {k * 2**(k-1) = } hashes in total.')
146147

147148
await mpc.shutdown()
148149

demos/sha3.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,16 @@ def shake(M, d, c=256):
148148
return keccak(c, N, d)
149149

150150

151+
async def xprint(text, s):
152+
"""Print and return bit array s as hex string."""
153+
s = await mpc.output(s)
154+
s = np.fliplr(s.reshape(-1, 8)).reshape(-1) # reverse bits for each byte
155+
d = len(s)
156+
s = f'{int("".join(str(int(b)) for b in s), 2):0{d//4}x}' # bits to hex digits with leading 0s
157+
print(f'{text} {s}')
158+
return s
159+
160+
151161
async def main():
152162
global keccak_f1600
153163

@@ -206,15 +216,11 @@ async def main():
206216

207217
X = args.i.encode() * args.n
208218
print(f'Input: {X}')
209-
210219
x = np.array([(b >> i) & 1 for b in X for i in range(8)]) # bytes to bits
211220
x = secfld.array(x) # secret-shared input bits
212-
y = F(x, d, c) # secret-shared output bits
213-
y = await mpc.output(y)
214-
y = np.fliplr(y.reshape(-1, 8)).reshape(-1) # reverse bits for each byte
215-
Y = f'{int("".join(str(int(b)) for b in y), 2):0{d//4}x}' # bits to hex digits with leading 0s
216221

217-
print(f'Output: {Y}')
222+
y = F(x, d, c) # secret-shared output bits
223+
Y = await xprint(f'Output:', y)
218224
assert Y == f(X).hexdigest(*e)
219225

220226
await mpc.shutdown()

mpyc/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
and statistics (securely mimicking Python’s statistics module).
2727
"""
2828

29-
__version__ = '0.8.9'
29+
__version__ = '0.8.10'
3030
__license__ = 'MIT License'
3131

3232
import os
@@ -86,7 +86,7 @@ def get_arg_parser():
8686

8787
group = parser.add_argument_group('MPyC misc')
8888
group.add_argument('--output-windows', action='store_true',
89-
help='screen output for parties 0<i<m (only on Windows)')
89+
help='screen output for parties 0<i<m (one window each)')
9090
group.add_argument('--output-file', action='store_true',
9191
help='append output of parties 0<i<m to party{m}_{i}.log')
9292
group.add_argument('-f', type=str, default='',

mpyc/runtime.py

Lines changed: 40 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -884,13 +884,15 @@ async def mul(self, a, b):
884884
else:
885885
a_integral = a.integral
886886
b_integral = shb and b.integral
887-
b_is_int = False
887+
z = 0
888888
if not shb:
889889
if isinstance(b, int):
890-
b_is_int = True
890+
z = f
891891
elif isinstance(b, float):
892892
b = round(b * 2**f)
893-
await self.returnType((stype, a_integral and (b_integral or b_is_int)))
893+
z = max(0, min(f, (b & -b).bit_length() - 1))
894+
b >>= z # remove trailing zeros
895+
await self.returnType((stype, a_integral and (b_integral or z == f)))
894896

895897
if not shb:
896898
a = await self.gather(a)
@@ -899,12 +901,12 @@ async def mul(self, a, b):
899901
else:
900902
a, b = await self.gather(a, b)
901903
c = a * b
902-
if f and (a_integral or b_integral) and not b_is_int:
903-
c >>= f # NB: in-place rshift
904+
if f and (a_integral or b_integral) and z != f:
905+
c >>= f - z # NB: in-place rshift
904906
if shb:
905907
c = self._reshare(c)
906-
if f and not (a_integral or b_integral) and not b_is_int:
907-
c = self.trunc(stype(c))
908+
if f and not (a_integral or b_integral) and z != f:
909+
c = self.trunc(stype(c), f=f - z)
908910
return c
909911

910912
@mpc_coro
@@ -920,21 +922,23 @@ async def np_multiply(self, a, b):
920922
else:
921923
a_integral = a.integral
922924
b_integral = shb and b.integral
923-
b_is_int = False
925+
z = 0
924926
if not shb:
925927
if isinstance(b, int):
926-
b_is_int = True
928+
z = f
927929
elif isinstance(b, float):
928930
b = round(b * 2**f)
931+
z = max(0, min(f, (b & -b).bit_length() - 1))
932+
b >>= z # remove trailing zeros
929933
elif isinstance(b, np.ndarray):
930934
if np.issubdtype(b.dtype, np.integer):
931-
b_is_int = True
935+
z = f
932936
elif np.issubdtype(b.dtype, np.floating):
933937
# NB: unlike for self.mul() no test if all entries happen to be integral
934938
# Scale to Python int entries (by setting otypes='O', prevents overflow):
935939
b = np.vectorize(round, otypes='O')(b * 2**f)
936940
# TODO: handle b.dtype=object, checking if all elts are int
937-
await self.returnType((stype, a_integral and (b_integral or b_is_int), shape))
941+
await self.returnType((stype, a_integral and (b_integral or z == f), shape))
938942

939943
if not shb:
940944
a = await self.gather(a)
@@ -943,12 +947,12 @@ async def np_multiply(self, a, b):
943947
else:
944948
a, b = await self.gather(a, b)
945949
c = a * b
946-
if f and (a_integral or b_integral) and not b_is_int:
947-
c >>= f # NB: in-place rshift
950+
if f and (a_integral or b_integral) and z != f:
951+
c >>= f - z # NB: in-place rshift
948952
if shb:
949953
c = self._reshare(c)
950-
if f and not (a_integral or b_integral) and not b_is_int:
951-
c = self.np_trunc(stype(c, shape=shape))
954+
if f and not (a_integral or b_integral) and z != f:
955+
c = self.np_trunc(stype(c, shape=shape), f=f - z)
952956
return c
953957

954958
def div(self, a, b):
@@ -2240,7 +2244,16 @@ async def np_reshape(self, a, shape, order='C'):
22402244
if isinstance(shape, int):
22412245
shape = (shape,) # ensure shape is a tuple
22422246
if -1 in shape:
2243-
raise ValueError('reshape with unknown dimension not allowed for secure arrays')
2247+
if shape.count(-1) > 1:
2248+
raise ValueError('can only specify one unknown dimension')
2249+
2250+
if (n := a.size) % (n1 := -math.prod(shape)) != 0:
2251+
raise ValueError(f'cannot reshape array of size {n} into shape {shape}')
2252+
2253+
i = shape.index(-1)
2254+
shape = list(shape)
2255+
shape[i] = n // n1
2256+
shape = tuple(shape)
22442257

22452258
if issubclass(stype, self.SecureFixedPointArray):
22462259
await self.returnType((stype, a.integral, shape))
@@ -2461,6 +2474,17 @@ def np_append(self, arr, values, axis=None):
24612474
"""
24622475
return self.np_concatenate((arr, values), axis=axis)
24632476

2477+
@mpc_coro_no_pc
2478+
async def np_fliplr(self, a):
2479+
"""Reverse the order of elements along axis 1 (left/right).
2480+
2481+
For a 2D array, this flips the entries in each row in the left/right direction.
2482+
Columns are preserved, but appear in a different order than before.
2483+
"""
2484+
await self.returnType((type(a), a.shape))
2485+
a = await self.gather(a)
2486+
return np.fliplr(a)
2487+
24642488
def np_minimum(self, a, b):
24652489
return b + (a < b) * (a - b)
24662490

mpyc/sectypes.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -588,8 +588,6 @@ def _coerce(self, other):
588588

589589
def _coerce2(self, other):
590590
if isinstance(other, float):
591-
if other.is_integer():
592-
other = round(other)
593591
return other
594592

595593
return super()._coerce2(other)
@@ -1030,6 +1028,10 @@ def __init__(self, value=None, shape=None):
10301028
self.shape = shape
10311029
super().__init__(value)
10321030

1031+
def __bool__(self):
1032+
"""Return True if secure array is nonempty, False otherwise."""
1033+
return bool(self.size)
1034+
10331035
def __array_function__(self, func, types, args, kwargs):
10341036
# minimal redirect for now
10351037
return eval(f'runtime.np_{func.__name__}')(*args, **kwargs)

tests/test_runtime.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def test_secint_array(self):
5555
np.assertEqual(mpc.run(mpc.output(np.row_stack((c, c, c)))), np.row_stack((a, a, a)))
5656
np.assertEqual(mpc.run(mpc.output(np.split(c, 2, 1)[0])), np.split(a, 2, 1)[0])
5757
np.assertEqual(mpc.run(mpc.output(np.vsplit(c, np.array([1]))[0])), np.vsplit(a, [1])[0])
58+
np.assertEqual(mpc.run(mpc.output(np.reshape(c, (-1,)))), np.reshape(a, (-1,)))
59+
np.assertEqual(mpc.run(mpc.output(np.fliplr(c))), np.fliplr(a))
5860
a1, a2 = a[:, :, 1], a[:, 0, :].reshape(2, 1)
5961
np.assertEqual(mpc.run(mpc.output(np.add(secnum.array(a1), secnum.array(a2)))), a1 + a2)
6062
np.assertEqual(mpc.run(mpc.output(a1 + secnum.array(a[:, 0, :]).reshape(2, 1))), a1 + a2)

0 commit comments

Comments
 (0)