Skip to content

Commit f6f1952

Browse files
committed
Add np_id3gini demo.
1 parent 32a75eb commit f6f1952

File tree

2 files changed

+141
-11
lines changed

2 files changed

+141
-11
lines changed

demos/id3gini.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
"""Demo decision tree learning using ID3.
22
3-
This demo implements a slight variant of Protocol 4.1 from the paper
4-
'Practical Secure Decision Tree Learning in a Teletreatment Application'
5-
by Sebastiaan de Hoogh, Berry Schoenmakers, Ping Chen, and Harm op den Akker,
6-
which appeared at the 18th International Conference on Financial Cryptography
7-
and Data Security (FC 2014), LNCS 8437, pp. 179-194, Springer.
3+
This demo implements Protocol 4.1 from the paper 'Practical Secure Decision
4+
Tree Learning in a Teletreatment Application' by Sebastiaan de Hoogh, Berry
5+
Schoenmakers, Ping Chen, and Harm op den Akker, which appeared at the 18th
6+
International Conference on Financial Cryptography and Data Security (FC 2014),
7+
LNCS 8437, pp. 179-194, Springer.
88
See https://doi.org/10.1007/978-3-662-45472-5_12 (or,
99
see https://fc14.ifca.ai/papers/fc14_submission_103.pdf or,
1010
see https://www.researchgate.net/publication/295148009).
@@ -74,17 +74,17 @@ async def id3(T, R) -> asyncio.Future:
7474
logging.info(f'Leaf node label {i}')
7575
tree = i
7676
else:
77-
T_R = [[mpc.schur_prod(T, v) for v in S[A]] for A in R]
78-
gains = [GI(mpc.matrix_prod(T_A, S[C], True)) for T_A in T_R]
77+
T_SC = [mpc.schur_prod(T, v) for v in S[C]]
78+
gains = [GI(mpc.matrix_prod(S[A], T_SC, True)) for A in R]
7979
k = await mpc.output(mpc.argmax(gains, key=SecureFraction)[0])
80-
T_Rk = T_R[k]
81-
del T_R, gains # release memory
80+
del gains # release memory
8281
A = list(R)[k]
82+
T_SA = [mpc.schur_prod(T, v) for v in S[A]]
8383
logging.info(f'Attribute node {A}')
8484
if args.parallel_subtrees:
85-
subtrees = await mpc.gather([id3(Tj, R.difference([A])) for Tj in T_Rk])
85+
subtrees = await mpc.gather([id3(Tj, R.difference([A])) for Tj in T_SA])
8686
else:
87-
subtrees = [await id3(Tj, R.difference([A])) for Tj in T_Rk]
87+
subtrees = [await id3(Tj, R.difference([A])) for Tj in T_SA]
8888
tree = A, subtrees
8989
return tree
9090

demos/np_id3gini.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
"""Demo decision tree learning using ID3, vectorized.
2+
3+
This demo is a fully equivalent reimplementation of the id3gini.py demo.
4+
Performance improvement of over 6x speed-up when run with three parties
5+
on local host. Memory consumption is reduced accordingly.
6+
7+
See id3gini.py for background information on decision tree learning and ID3.
8+
"""
9+
# TODO: vectorize mpc.argmax()
10+
11+
import os
12+
import logging
13+
import argparse
14+
import csv
15+
import asyncio
16+
import numpy as np
17+
from mpyc.runtime import mpc
18+
19+
20+
@mpc.coroutine
21+
async def id3(T, R) -> asyncio.Future:
22+
sizes = S[C] @ T
23+
i, mx = mpc.argmax(sizes)
24+
sizeT = sizes.sum()
25+
stop = (sizeT <= int(args.epsilon * len(T))) + (mx == sizeT)
26+
if not (R and await mpc.is_zero_public(stop)):
27+
i = await mpc.output(i)
28+
logging.info(f'Leaf node label {i}')
29+
tree = i
30+
else:
31+
T_SC = (T * S[C]).T
32+
k = mpc.argmax([GI(S[A] @ T_SC) for A in R], key=SecureFraction)[0]
33+
A = list(R)[await mpc.output(k)]
34+
logging.info(f'Attribute node {A}')
35+
T_SA = T * S[A]
36+
if args.parallel_subtrees:
37+
subtrees = await mpc.gather([id3(Tj, R.difference([A])) for Tj in T_SA])
38+
else:
39+
subtrees = [await id3(Tj, R.difference([A])) for Tj in T_SA]
40+
tree = A, subtrees
41+
return tree
42+
43+
44+
def GI(x):
45+
"""Gini impurity for contingency table x."""
46+
y = args.alpha * np.sum(x, axis=1) + 1 # NB: alternatively, use s + (s == 0)
47+
D = mpc.prod(y.tolist())
48+
G = np.sum(np.sum(x * x, axis=1) / y)
49+
return [D * G, D] # numerator, denominator
50+
51+
52+
class SecureFraction:
53+
def __init__(self, a):
54+
self.n, self.d = a # numerator, denominator
55+
56+
def __lt__(self, other): # NB: __lt__() is basic comparison as in Python's list.sort()
57+
return mpc.in_prod([self.n, -self.d], [other.d, other.n]) < 0
58+
59+
60+
depth = lambda tree: 0 if isinstance(tree, int) else max(map(depth, tree[1])) + 1
61+
62+
size = lambda tree: 1 if isinstance(tree, int) else sum(map(size, tree[1])) + 1
63+
64+
65+
def pretty(prefix, tree, names, ranges):
66+
"""Convert raw tree into multiline textual representation, using attribute names and values."""
67+
if isinstance(tree, int): # leaf
68+
return ranges[C][tree]
69+
70+
A, subtrees = tree
71+
s = ''
72+
for a, t in zip(ranges[A], subtrees):
73+
s += f'\n{prefix}{names[A]} == {a}: {pretty("| " + prefix, t, names, ranges)}'
74+
return s
75+
76+
77+
async def main():
78+
global args, secint, C, S
79+
80+
parser = argparse.ArgumentParser()
81+
parser.add_argument('-i', '--dataset', type=int, metavar='I',
82+
help=('dataset 0=tennis (default), 1=balance-scale, 2=car, '
83+
'3=SPECT, 4=KRKPA7, 5=tic-tac-toe, 6=house-votes-84'))
84+
parser.add_argument('-l', '--bit-length', type=int, metavar='L',
85+
help='override preset bit length for dataset')
86+
parser.add_argument('-e', '--epsilon', type=float, metavar='E',
87+
help='minimum fraction E of samples for a split, 0.0<=E<=1.0')
88+
parser.add_argument('-a', '--alpha', type=int, metavar='A',
89+
help='scale factor A to prevent division by zero, A>=1')
90+
parser.add_argument('--parallel-subtrees', action='store_true',
91+
help='process subtrees in parallel (rather than in series)')
92+
parser.add_argument('--no-pretty-tree', action='store_true',
93+
help='print raw flat tree instead of pretty tree')
94+
parser.set_defaults(dataset=0, bit_length=0, alpha=8, epsilon=0.05)
95+
args = parser.parse_args()
96+
97+
settings = [('tennis', 32), ('balance-scale', 77), ('car', 95),
98+
('SPECT', 42), ('KRKPA7', 69), ('tic-tac-toe', 75), ('house-votes-84', 62)]
99+
name, bit_length = settings[args.dataset]
100+
if args.bit_length:
101+
bit_length = args.bit_length
102+
secint = mpc.SecInt(bit_length)
103+
print(f'Using secure integers: {secint.__name__}')
104+
105+
with open(os.path.join('data', 'id3', name + '.csv')) as file:
106+
reader = csv.reader(file)
107+
attr_names = next(reader)
108+
C = 0 if attr_names[0].lower().startswith('class') else len(attr_names)-1 # class attribute
109+
transactions = list(reader)
110+
n, d = len(transactions), len(attr_names)
111+
attr_ranges = [sorted(set(t[i] for t in transactions)) for i in range(d)]
112+
# one-hot encoding of attributes:
113+
S = [secint.array(np.array([[t[i] == j for t in transactions] for j in attr_ranges[i]]))
114+
for i in range(d)]
115+
T = secint.array(np.ones(n, dtype='O'))
116+
print(f'dataset: {name} with {n} samples and {d-1} attributes')
117+
118+
await mpc.start()
119+
tree = await id3(T, frozenset(range(d)).difference([C]))
120+
await mpc.shutdown()
121+
122+
print(f'Decision tree of depth {depth(tree)} and size {size(tree)}: ', end='')
123+
if args.no_pretty_tree:
124+
print(tree)
125+
else:
126+
print(pretty('if ', tree, attr_names, attr_ranges))
127+
128+
129+
if __name__ == '__main__':
130+
mpc.run(main())

0 commit comments

Comments
 (0)