|
| 1 | +"""Demo decision tree learning using ID3, vectorized. |
| 2 | +
|
| 3 | +This demo is a fully equivalent reimplementation of the id3gini.py demo. |
| 4 | +Performance improvement of over 6x speed-up when run with three parties |
| 5 | +on local host. Memory consumption is reduced accordingly. |
| 6 | +
|
| 7 | +See id3gini.py for background information on decision tree learning and ID3. |
| 8 | +""" |
| 9 | +# TODO: vectorize mpc.argmax() |
| 10 | + |
| 11 | +import os |
| 12 | +import logging |
| 13 | +import argparse |
| 14 | +import csv |
| 15 | +import asyncio |
| 16 | +import numpy as np |
| 17 | +from mpyc.runtime import mpc |
| 18 | + |
| 19 | + |
| 20 | +@mpc.coroutine |
| 21 | +async def id3(T, R) -> asyncio.Future: |
| 22 | + sizes = S[C] @ T |
| 23 | + i, mx = mpc.argmax(sizes) |
| 24 | + sizeT = sizes.sum() |
| 25 | + stop = (sizeT <= int(args.epsilon * len(T))) + (mx == sizeT) |
| 26 | + if not (R and await mpc.is_zero_public(stop)): |
| 27 | + i = await mpc.output(i) |
| 28 | + logging.info(f'Leaf node label {i}') |
| 29 | + tree = i |
| 30 | + else: |
| 31 | + T_SC = (T * S[C]).T |
| 32 | + k = mpc.argmax([GI(S[A] @ T_SC) for A in R], key=SecureFraction)[0] |
| 33 | + A = list(R)[await mpc.output(k)] |
| 34 | + logging.info(f'Attribute node {A}') |
| 35 | + T_SA = T * S[A] |
| 36 | + if args.parallel_subtrees: |
| 37 | + subtrees = await mpc.gather([id3(Tj, R.difference([A])) for Tj in T_SA]) |
| 38 | + else: |
| 39 | + subtrees = [await id3(Tj, R.difference([A])) for Tj in T_SA] |
| 40 | + tree = A, subtrees |
| 41 | + return tree |
| 42 | + |
| 43 | + |
| 44 | +def GI(x): |
| 45 | + """Gini impurity for contingency table x.""" |
| 46 | + y = args.alpha * np.sum(x, axis=1) + 1 # NB: alternatively, use s + (s == 0) |
| 47 | + D = mpc.prod(y.tolist()) |
| 48 | + G = np.sum(np.sum(x * x, axis=1) / y) |
| 49 | + return [D * G, D] # numerator, denominator |
| 50 | + |
| 51 | + |
| 52 | +class SecureFraction: |
| 53 | + def __init__(self, a): |
| 54 | + self.n, self.d = a # numerator, denominator |
| 55 | + |
| 56 | + def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort() |
| 57 | + return mpc.in_prod([self.n, -self.d], [other.d, other.n]) < 0 |
| 58 | + |
| 59 | + |
| 60 | +depth = lambda tree: 0 if isinstance(tree, int) else max(map(depth, tree[1])) + 1 |
| 61 | + |
| 62 | +size = lambda tree: 1 if isinstance(tree, int) else sum(map(size, tree[1])) + 1 |
| 63 | + |
| 64 | + |
| 65 | +def pretty(prefix, tree, names, ranges): |
| 66 | + """Convert raw tree into multiline textual representation, using attribute names and values.""" |
| 67 | + if isinstance(tree, int): # leaf |
| 68 | + return ranges[C][tree] |
| 69 | + |
| 70 | + A, subtrees = tree |
| 71 | + s = '' |
| 72 | + for a, t in zip(ranges[A], subtrees): |
| 73 | + s += f'\n{prefix}{names[A]} == {a}: {pretty("| " + prefix, t, names, ranges)}' |
| 74 | + return s |
| 75 | + |
| 76 | + |
| 77 | +async def main(): |
| 78 | + global args, secint, C, S |
| 79 | + |
| 80 | + parser = argparse.ArgumentParser() |
| 81 | + parser.add_argument('-i', '--dataset', type=int, metavar='I', |
| 82 | + help=('dataset 0=tennis (default), 1=balance-scale, 2=car, ' |
| 83 | + '3=SPECT, 4=KRKPA7, 5=tic-tac-toe, 6=house-votes-84')) |
| 84 | + parser.add_argument('-l', '--bit-length', type=int, metavar='L', |
| 85 | + help='override preset bit length for dataset') |
| 86 | + parser.add_argument('-e', '--epsilon', type=float, metavar='E', |
| 87 | + help='minimum fraction E of samples for a split, 0.0<=E<=1.0') |
| 88 | + parser.add_argument('-a', '--alpha', type=int, metavar='A', |
| 89 | + help='scale factor A to prevent division by zero, A>=1') |
| 90 | + parser.add_argument('--parallel-subtrees', action='store_true', |
| 91 | + help='process subtrees in parallel (rather than in series)') |
| 92 | + parser.add_argument('--no-pretty-tree', action='store_true', |
| 93 | + help='print raw flat tree instead of pretty tree') |
| 94 | + parser.set_defaults(dataset=0, bit_length=0, alpha=8, epsilon=0.05) |
| 95 | + args = parser.parse_args() |
| 96 | + |
| 97 | + settings = [('tennis', 32), ('balance-scale', 77), ('car', 95), |
| 98 | + ('SPECT', 42), ('KRKPA7', 69), ('tic-tac-toe', 75), ('house-votes-84', 62)] |
| 99 | + name, bit_length = settings[args.dataset] |
| 100 | + if args.bit_length: |
| 101 | + bit_length = args.bit_length |
| 102 | + secint = mpc.SecInt(bit_length) |
| 103 | + print(f'Using secure integers: {secint.__name__}') |
| 104 | + |
| 105 | + with open(os.path.join('data', 'id3', name + '.csv')) as file: |
| 106 | + reader = csv.reader(file) |
| 107 | + attr_names = next(reader) |
| 108 | + C = 0 if attr_names[0].lower().startswith('class') else len(attr_names)-1 # class attribute |
| 109 | + transactions = list(reader) |
| 110 | + n, d = len(transactions), len(attr_names) |
| 111 | + attr_ranges = [sorted(set(t[i] for t in transactions)) for i in range(d)] |
| 112 | + # one-hot encoding of attributes: |
| 113 | + S = [secint.array(np.array([[t[i] == j for t in transactions] for j in attr_ranges[i]])) |
| 114 | + for i in range(d)] |
| 115 | + T = secint.array(np.ones(n, dtype='O')) |
| 116 | + print(f'dataset: {name} with {n} samples and {d-1} attributes') |
| 117 | + |
| 118 | + await mpc.start() |
| 119 | + tree = await id3(T, frozenset(range(d)).difference([C])) |
| 120 | + await mpc.shutdown() |
| 121 | + |
| 122 | + print(f'Decision tree of depth {depth(tree)} and size {size(tree)}: ', end='') |
| 123 | + if args.no_pretty_tree: |
| 124 | + print(tree) |
| 125 | + else: |
| 126 | + print(pretty('if ', tree, attr_names, attr_ranges)) |
| 127 | + |
| 128 | + |
| 129 | +if __name__ == '__main__': |
| 130 | + mpc.run(main()) |
0 commit comments