|
| 1 | +"""Demo Linear Programming (LP) solver, using secure integer arithmetic. |
| 2 | +
|
| 3 | +Vectorized. |
| 4 | +
|
| 5 | +See demo lpsolver.py for background information. |
| 6 | +
|
| 7 | +... work in progress for MPyC version 0.9 |
| 8 | +""" |
| 9 | + |
| 10 | +import os |
| 11 | +import logging |
| 12 | +import argparse |
| 13 | +import csv |
| 14 | +import math |
| 15 | +import numpy as np |
| 16 | +from mpyc.runtime import mpc |
| 17 | + |
| 18 | + |
| 19 | +# TODO: unify approaches for argmin etc. with secure NumPy arrays (also see lpsolverfxp.py) |
| 20 | + |
| 21 | + |
| 22 | +def argmin_int(x): |
| 23 | + secintarr = type(x) |
| 24 | + n = len(x) |
| 25 | + if n == 1: |
| 26 | + return (secintarr(np.array([1])), x[0]) |
| 27 | + |
| 28 | + if n == 2: |
| 29 | + b = x[0] < x[1] |
| 30 | + arg = mpc.np_fromlist([b, 1 - b]) |
| 31 | + min = b * (x[0] - x[1]) + x[1] |
| 32 | + return arg, min |
| 33 | + |
| 34 | + a = x[:n//2] ## split even odd? start at n%2 as in reduce in mpctools |
| 35 | + b = x[(n+1)//2:] |
| 36 | + c = a < b |
| 37 | + m = c * (a - b) + b |
| 38 | + if n%2 == 1: |
| 39 | + m = np.concatenate((m, x[n//2:(n+1)//2])) |
| 40 | + ag, mn = argmin_int(m) |
| 41 | + if n%2 == 1: |
| 42 | + ag_1 = ag[-1:] |
| 43 | + ag = ag[:-1] |
| 44 | + arg1 = ag * c |
| 45 | + arg2 = ag - arg1 |
| 46 | + if n%2 == 1: |
| 47 | + arg = np.concatenate((arg1, ag_1, arg2)) |
| 48 | + else: |
| 49 | + arg = np.concatenate((arg1, arg2)) |
| 50 | + return arg, mn |
| 51 | + |
| 52 | + |
| 53 | +def argmin_rat(nd): |
| 54 | + secintarr = type(nd) |
| 55 | + N = nd.shape[1] |
| 56 | + if N == 1: |
| 57 | + return (secintarr(np.array([1])), (nd[0, 0], nd[1, 0])) |
| 58 | + |
| 59 | + if N == 2: |
| 60 | + b = mpc.in_prod([nd[0, 0], -nd[1, 0]], [nd[1, 1], nd[0, 1]]) < 0 |
| 61 | + arg = mpc.np_fromlist([b, 1 - b]) |
| 62 | + min = b * (nd[:, 0] - nd[:, 1]) + nd[:, 1] |
| 63 | + return arg, (min[0], min[1]) |
| 64 | + |
| 65 | + a = nd[:, :N//2] |
| 66 | + b = nd[:, (N+1)//2:] |
| 67 | + c = a[0] * b[1] < b[0] * a[1] |
| 68 | + m = c * (a - b) + b |
| 69 | + if N%2 == 1: |
| 70 | + m = np.concatenate((m, nd[:, N//2:(N+1)//2]), axis=1) |
| 71 | + ag, mn = argmin_rat(m) |
| 72 | + if N%2 == 1: |
| 73 | + ag_1 = ag[-1:] |
| 74 | + ag = ag[:-1] |
| 75 | + arg1 = ag * c |
| 76 | + arg2 = ag - arg1 |
| 77 | + if N%2 == 1: |
| 78 | + arg = np.concatenate((arg1, ag_1, arg2)) |
| 79 | + else: |
| 80 | + arg = np.concatenate((arg1, arg2)) |
| 81 | + return arg, mn |
| 82 | + |
| 83 | + |
| 84 | +def pow_list(a, x, n): |
| 85 | + """Return [a,ax, ax^2, ..., ax^(n-1)]. |
| 86 | +
|
| 87 | + Runs in O(log n) rounds using minimal number of n-1 secure multiplications. |
| 88 | + """ |
| 89 | + if n == 1: |
| 90 | + powers = mpc.np_fromlist([a]) |
| 91 | + elif n == 2: |
| 92 | + powers = mpc.np_fromlist([a, a * x]) |
| 93 | + else: |
| 94 | + even_powers = pow_list(a, x * x, (n+1)//2) |
| 95 | + if n%2: |
| 96 | + d = even_powers[-1] |
| 97 | + even_powers = even_powers[:-1] |
| 98 | + odd_powers = x * even_powers |
| 99 | + powers = np.stack((even_powers, odd_powers)) |
| 100 | + powers = powers.T.reshape(n - (n%2)) |
| 101 | + if n%2: |
| 102 | + powers = np.append(powers, mpc.np_fromlist([d])) |
| 103 | + return powers |
| 104 | + |
| 105 | + |
| 106 | +# TODO: consider next version of unit_vector() as alternative for current mpc.unit_vector() |
| 107 | +@mpc.coroutine |
| 108 | +async def unit_vector(a, n): |
| 109 | + """Length-n unit vector [0]*a + [1] + [0]*(n-1-a) for secret a, assuming 0 <= a < n. |
| 110 | +
|
| 111 | + NB: If a = n, unit vector [1] + [0]*(n-1) is returned. See mpyc.statistics. |
| 112 | + """ |
| 113 | + await mpc.returnType(type(a), n) |
| 114 | + # TODO: add conversion from prime field GF(p) to secint, 0<=a<n, n < p/2 needed? |
| 115 | + u = await mpc.gather(mpc.random.random_unit_vector(type(a), n)) |
| 116 | + r = sum(a * b for a, b in zip(u, range(n))) |
| 117 | + R = 1 + mpc._random(type(a).field, 1<<mpc.options.sec_param) |
| 118 | + c = await mpc.output(a - r + R * n) |
| 119 | + c %= n |
| 120 | + # rotate u over c positions to the right |
| 121 | + v = u[:n-c] |
| 122 | + u = u[n-c:] |
| 123 | + u.extend(v) |
| 124 | + return u |
| 125 | + |
| 126 | + |
| 127 | +async def main(): |
| 128 | + parser = argparse.ArgumentParser() |
| 129 | + parser.add_argument('-i', '--dataset', type=int, metavar='I', |
| 130 | + help=('dataset 0=uvlp (default), 1=wiki, 2=tb2x2, 3=woody, ' |
| 131 | + '4=LPExample_R20, 5=sc50b, 6=kb2, 7=LPExample')) |
| 132 | + parser.add_argument('-l', '--bit-length', type=int, metavar='L', |
| 133 | + help='override preset bit length for dataset') |
| 134 | + parser.set_defaults(dataset=0, bit_length=0) |
| 135 | + args = parser.parse_args() |
| 136 | + |
| 137 | + settings = [('uvlp', 8, 1, 2), |
| 138 | + ('wiki', 6, 1, 2), |
| 139 | + ('tb2x2', 6, 1, 2), |
| 140 | + ('woody', 8, 1, 3), |
| 141 | + ('LPExample_R20', 70, 1, 5), |
| 142 | + ('sc50b', 104, 10, 55), |
| 143 | + ('kb2', 536, 100000, 106), |
| 144 | + ('LPExample', 110, 1, 178)] |
| 145 | + name, bit_length, scale, n_iter = settings[args.dataset] |
| 146 | + if args.bit_length: |
| 147 | + bit_length = args.bit_length |
| 148 | + |
| 149 | + T = np.genfromtxt(os.path.join('data', 'lp', name + '.csv'), dtype=float, delimiter=',') |
| 150 | + m, n = T.shape[0] - 1, T.shape[1] - 1 |
| 151 | + secint = mpc.SecInt(bit_length, n=m + n) # force existence of Nth root of unity, N>=m+n |
| 152 | + print(f'Using secure {bit_length}-bit integers: {secint.__name__}') |
| 153 | + print(f'dataset: {name} with {m} constraints and {n} variables (scale factor {scale})') |
| 154 | + T[0, -1] = 0.0 # initialize optimal value |
| 155 | + T = np.vectorize(int, otypes='O')(T * scale) |
| 156 | + g = np.gcd.reduce(T[1:], axis=1, keepdims=True) |
| 157 | + T[1:] //= np.maximum(g, 1) # remove common factors per row (skipping cost row) |
| 158 | + T = secint.array(T) |
| 159 | + c, A, b = -T[0, :-1], T[1:, :-1], T[1:, -1] # maximize c.x subject to A.x <= b, x >= 0 |
| 160 | + |
| 161 | + Zp = secint.field |
| 162 | + N = Zp.nth |
| 163 | + w = Zp.root # w is an Nth root of unity in Zp, where N >= m + n |
| 164 | + w_powers = [Zp(1)] |
| 165 | + for _ in range(N-1): |
| 166 | + w_powers.append(w_powers[-1] * w) |
| 167 | + assert w_powers[-1] * w == 1 |
| 168 | + |
| 169 | + await mpc.start() |
| 170 | + |
| 171 | + cobasis = Zp.array(np.array([w_powers[-j].value for j in range(n)])) |
| 172 | + basis = Zp.array(np.array([w_powers[-(i + n)].value for i in range(m)])) |
| 173 | + previous_pivot = secint(1) |
| 174 | + |
| 175 | + iteration = 0 |
| 176 | + while True: |
| 177 | + # find index of pivot column |
| 178 | + p_col_index, minimum = argmin_int(T[0, :-1]) |
| 179 | + if await mpc.output(minimum >= 0): |
| 180 | + break # maximum reached |
| 181 | + |
| 182 | + # find index of pivot row |
| 183 | + p_col = T[:, :-1] @ p_col_index |
| 184 | + den = p_col[1:] |
| 185 | + num = T[1:, -1] + (den <= 0) |
| 186 | + p_row_index, (_, pivot) = argmin_rat(np.stack((num, den))) |
| 187 | + |
| 188 | + # reveal progress a bit |
| 189 | + iteration += 1 |
| 190 | + mx, cd, p = await mpc.output([T[0, -1], previous_pivot, pivot]) |
| 191 | + logging.info(f'Iteration {iteration}/{n_iter}: {mx / cd} pivot={p / cd}') |
| 192 | + |
| 193 | + # swap basis entries |
| 194 | + delta = basis @ p_row_index - cobasis @ p_col_index |
| 195 | + cobasis += delta * p_col_index |
| 196 | + basis -= delta * p_row_index |
| 197 | + |
| 198 | + # update tableau Tij = Tij*Tkl/Tkl' - (Til/Tkl' - bool(i==k)) * (Tkj + bool(j==l)*Tkl') |
| 199 | + p_col_index = np.concatenate((p_col_index, np.array([0]))) |
| 200 | + p_row_index = np.concatenate((np.array([0]), p_row_index)) |
| 201 | + |
| 202 | + pp_inv = 1 / previous_pivot |
| 203 | + p_col = p_col * pp_inv - p_row_index |
| 204 | + p_row = p_row_index @ T + previous_pivot * p_col_index |
| 205 | + T = T * (pivot * pp_inv) - p_col.reshape(len(p_col), 1) @ p_row.reshape(1, len(p_row)) # consider np.gauss |
| 206 | + previous_pivot = pivot |
| 207 | + |
| 208 | + mx = await mpc.output(T[0, -1]) |
| 209 | + cd = await mpc.output(previous_pivot) # common denominator for all entries of T |
| 210 | + print(f'max = {mx} / {cd} / {scale} = {mx / cd / scale} in {iteration} iterations') |
| 211 | + |
| 212 | + logging.info('Solution x') |
| 213 | + sum_powers = secint.array(np.zeros((N,), dtype='O')) # TODO: compare with fromiter approach below |
| 214 | + for i in range(m): |
| 215 | + x_powers = pow_list(T[i+1][-1] / N, basis[i], N) |
| 216 | + sum_powers += x_powers |
| 217 | + coefs = Zp.array([[w_powers[(j * k) % N].value for k in range(N)] for j in range(n)]) ## TODO: vandermonde ff array |
| 218 | + x = coefs @ sum_powers |
| 219 | + Ax_bounded_by_b = mpc.all((A @ x <= b * cd).tolist()) ## TODO: np.all |
| 220 | + x_nonnegative = mpc.all((x >= 0).tolist()) |
| 221 | + |
| 222 | + logging.info('Dual solution y') |
| 223 | + coefs = Zp.array([[w_powers[((n + i) * k) % N].value for k in range(N)] for i in range(m)]) |
| 224 | + sum_powers = np.sum(np.fromiter((pow_list(T[0][j] / N, cobasis[j], N) for j in range(n)), 'O')) #TODO fix for Numpy 1.22, WSL |
| 225 | + y = coefs @ sum_powers |
| 226 | + yA_bounded_by_c = mpc.all((y @ A >= c * cd).tolist()) |
| 227 | + y_nonnegative = mpc.all((y >= 0).tolist()) |
| 228 | + |
| 229 | + cx_eq_yb = c @ x == y @ b |
| 230 | + check = mpc.all([cx_eq_yb, Ax_bounded_by_b, x_nonnegative, yA_bounded_by_c, y_nonnegative]) |
| 231 | + check = bool(await mpc.output(check)) |
| 232 | + print(f'verification c.x == y.b, A.x <= b, x >= 0, y.A >= c, y >= 0: {check}') |
| 233 | + |
| 234 | + x = await mpc.output(x) |
| 235 | + print(f'solution = {[a / cd for a in x]}') |
| 236 | + |
| 237 | + await mpc.shutdown() |
| 238 | + |
| 239 | +if __name__ == '__main__': |
| 240 | + mpc.run(main()) |
0 commit comments