Skip to content

Commit 96f352d

Browse files
committed
tests treeutils, TransferTreeClassifier and TransferForestClassifier
1 parent 43439de commit 96f352d

File tree

2 files changed

+230
-38
lines changed

2 files changed

+230
-38
lines changed

tests/test_transfertree.py

Lines changed: 95 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,30 @@
11
import copy
22
import numpy as np
33
from sklearn.tree import DecisionTreeClassifier
4+
from sklearn.ensemble import RandomForestClassifier
45

5-
from adapt.parameter_based import TransferTreeClassifier
6+
from adapt.parameter_based import TransferTreeClassifier, TransferForestClassifier
67

78
methods = [
89
'relab',
910
'ser',
1011
'strut',
1112
'ser_nr',
13+
'ser_no_ext',
1214
'ser_nr_lambda',
1315
'strut_nd',
1416
'strut_lambda',
15-
'strut_lambda_np'
17+
'strut_np',
18+
'strut_lambda_np',
19+
'strut_lambda_np2'
1620
# 'strut_hi'
1721
]
18-
labels = [
19-
'relab',
20-
'$SER$',
21-
'$STRUT$',
22-
'$SER_{NP}$',
23-
'$SER_{NP}(\lambda)$',
24-
'$STRUT_{ND}$',
25-
'$STRUT(\lambda)$',
26-
'$STRUT_{NP}(\lambda)$'
27-
# 'STRUT$^{*}$',
28-
#'STRUT$^{*}$',
29-
]
22+
3023

3124
def test_transfer_tree():
3225

3326
np.random.seed(0)
3427

35-
plot_step = 0.01
3628
# Generate training source data
3729
ns = 200
3830
ns_perclass = ns // 2
@@ -65,12 +57,15 @@ def test_transfer_tree():
6557
yt_test[nt_test_perclass:] = 1
6658

6759
# Source classifier
68-
clf_source = DecisionTreeClassifier(max_depth=None)
69-
clf_source.fit(Xs, ys)
70-
score_src_src = clf_source.score(Xs, ys)
71-
score_src_trgt = clf_source.score(Xt_test, yt_test)
72-
print('Training score Source model: {:.3f}'.format(score_src_src))
73-
print('Testing score Source model: {:.3f}'.format(score_src_trgt))
60+
RF_SIZE = 10
61+
clf_source_dt = DecisionTreeClassifier(max_depth=None)
62+
clf_source_rf = RandomForestClassifier(n_estimators=RF_SIZE)
63+
clf_source_dt.fit(Xs, ys)
64+
clf_source_rf.fit(Xs, ys)
65+
#score_src_src = clf_source.score(Xs, ys)
66+
#score_src_trgt = clf_source.score(Xt_test, yt_test)
67+
#print('Training score Source model: {:.3f}'.format(score_src_src))
68+
#print('Testing score Source model: {:.3f}'.format(score_src_trgt))
7469
clfs = []
7570
scores = []
7671
# Transfer with SER
@@ -79,7 +74,7 @@ def test_transfer_tree():
7974

8075
for method in methods:
8176
Nkmin = sum(yt == 0 )
82-
root_source_values = clf_source.tree_.value[0].reshape(-1)
77+
root_source_values = clf_source_dt.tree_.value[0].reshape(-1)
8378
props_s = root_source_values
8479
props_s = props_s / sum(props_s)
8580
props_t = np.zeros(props_s.size)
@@ -88,43 +83,105 @@ def test_transfer_tree():
8883

8984
coeffs = np.divide(props_t, props_s)
9085

91-
clf_transfer = copy.deepcopy(clf_source)
86+
clf_transfer_dt = copy.deepcopy(clf_source_dt)
87+
clf_transfer_rf = copy.deepcopy(clf_source_rf)
88+
9289
if method == 'relab':
93-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="")
90+
#decision tree
91+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="")
9492
transferred_dt.fit(Xt,yt)
93+
#random forest
94+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="",bootstrap=True)
95+
transferred_rf.fit(Xt,yt)
9596
if method == 'ser':
96-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="ser")
97+
#decision tree
98+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser",max_depth=10)
9799
transferred_dt.fit(Xt,yt)
98-
#transferred_dt._ser(Xt, yt, node=0, original_ser=True)
99-
#ser.SER(0, clf_transfer, Xt, yt, original_ser=True)
100+
#random forest
101+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
102+
transferred_rf.fit(Xt,yt)
100103
if method == 'ser_nr':
101-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="ser")
104+
#decision tree
105+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
102106
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0])
107+
#random forest
108+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
109+
transferred_rf._ser_rf(Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0])
110+
if method == 'ser_no_ext':
111+
#decision tree
112+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
113+
# no_ext_on_cl is paired with cl_no_ext, as in the forest call of this branch
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
114+
#random forest
115+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
116+
transferred_rf._ser_rf(Xt, yt,original_ser=False,no_ext_on_cl=True,cl_no_ext=[0],ext_cond=True)
103117
if method == 'ser_nr_lambda':
104-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="ser")
118+
#decision tree
119+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="ser")
105120
transferred_dt._ser(Xt, yt,node=0,original_ser=False,no_red_on_cl=True,cl_no_red=[0],
106121
leaf_loss_quantify=True,leaf_loss_threshold=0.5,
107122
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
108-
#ser.SER(0, clf_transfer, Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],ext_cond=True)
123+
#random forest
124+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="ser")
125+
transferred_rf._ser_rf(Xt, yt,original_ser=False,no_red_on_cl=True,cl_no_red=[0],
126+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,
127+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
109128
if method == 'strut':
110-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="strut")
129+
#decision tree
130+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
111131
transferred_dt.fit(Xt,yt)
112-
#transferred_dt._strut(Xt, yt,node=0)
132+
#random forest
133+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
134+
transferred_rf.fit(Xt,yt)
113135
if method == 'strut_nd':
114-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="strut")
136+
#decision tree
137+
# the tree wrapper must receive the decision-tree copy, not the forest copy
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
115138
transferred_dt._strut(Xt, yt,node=0,use_divergence=False)
139+
#random forest
140+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
141+
transferred_rf._strut_rf(Xt, yt,use_divergence=False)
116142
if method == 'strut_lambda':
117-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="strut")
143+
#decision tree
144+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
118145
transferred_dt._strut(Xt, yt,node=0,adapt_prop=True,root_source_values=root_source_values,
119146
Nkmin=Nkmin,coeffs=coeffs)
147+
#random forest
148+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
149+
transferred_rf._strut_rf(Xt, yt,adapt_prop=True,root_source_values=root_source_values,
150+
Nkmin=Nkmin,coeffs=coeffs)
151+
if method == 'strut_np':
152+
#decision tree
153+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
154+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
155+
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
156+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
157+
#random forest
158+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
159+
transferred_rf._strut(Xt, yt,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
160+
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
161+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
120162
if method == 'strut_lambda_np':
121-
transferred_dt = TransferTreeClassifier(estimator=clf_transfer,algo="strut")
163+
#decision tree
164+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
165+
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
166+
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
167+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
168+
#random forest
169+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
170+
transferred_rf._strut(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
171+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=False,
172+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
173+
if method == 'strut_lambda_np2':
174+
#decision tree
175+
transferred_dt = TransferTreeClassifier(estimator=clf_transfer_dt,algo="strut")
122176
transferred_dt._strut(Xt, yt,node=0,adapt_prop=False,no_prune_on_cl=True,cl_no_prune=[0],
123177
leaf_loss_quantify=False,leaf_loss_threshold=0.5,no_prune_with_translation=False,
124178
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
125-
#if method == 'strut_hi':
126-
#transferred_dt._strut(Xt, yt,node=0,no_prune_on_cl=False,adapt_prop=True,coeffs=[0.2, 1])
127-
#strut.STRUT(clf_transfer, 0, Xt, yt, Xt, yt,pruning_updated_node=True,no_prune_on_cl=False,adapt_prop=True,simple_weights=False,coeffs=[0.2, 1])
179+
#random forest
180+
transferred_rf = TransferForestClassifier(estimator=clf_transfer_rf,algo="strut")
181+
transferred_rf._strut(Xt, yt,adapt_prop=True,no_prune_on_cl=True,cl_no_prune=[0],
182+
leaf_loss_quantify=True,leaf_loss_threshold=0.5,no_prune_with_translation=True,
183+
root_source_values=root_source_values,Nkmin=Nkmin,coeffs=coeffs)
184+
128185
score = transferred_dt.estimator.score(Xt_test, yt_test)
129186
#score = clf_transfer.score(Xt_test, yt_test)
130187
print('Testing score transferred model ({}) : {:.3f}'.format(method, score))

tests/test_treeutils.py

Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
import numpy as np
2+
from sklearn.tree import DecisionTreeClassifier
3+
from sklearn.ensemble import RandomForestClassifier
4+
5+
import adapt._tree_utils as ut
6+
7+
# Shared fixtures for every test in this module: synthetic two-class source
# and target samples plus one fitted tree and one fitted forest.
# The order of the random draws and of the two fits below is deliberate:
# everything relies on the global NumPy seed set on the next line.
np.random.seed(0)


def _two_class_sample(mean_a, cov_a, n_a, mean_b, cov_b, n_b):
    """Stack ``n_a`` draws of the first Gaussian on top of ``n_b`` of the second."""
    first = np.random.multivariate_normal(mean_a, cov_a, size=n_a)
    second = np.random.multivariate_normal(mean_b, cov_b, size=n_b)
    return np.concatenate([first, second], axis=0)


# --- Source training data: two equally sized classes -----------------------
ns = 200
ns_perclass = ns // 2
mean_1 = (1, 1)
var_1 = np.diag([1, 1])
mean_2 = (3, 3)
var_2 = np.diag([2, 2])
Xs = _two_class_sample(mean_1, var_1, ns_perclass, mean_2, var_2, ns_perclass)
ys = np.concatenate([np.zeros(ns_perclass), np.ones(ns - ns_perclass)])

# --- Target training data: imbalanced, class 0 gets a tenth of the points --
nt = 50
nt_0 = nt // 10
mean_1 = (6, 3)
var_1 = np.diag([4, 1])
mean_2 = (5, 5)
var_2 = np.diag([1, 3])
Xt = _two_class_sample(mean_1, var_1, nt_0, mean_2, var_2, nt - nt_0)
yt = np.concatenate([np.zeros(nt_0), np.ones(nt - nt_0)])

# --- Target testing data: balanced, drawn from the target distributions ----
nt_test = 1000
nt_test_perclass = nt_test // 2
Xt_test = _two_class_sample(mean_1, var_1, nt_test_perclass,
                            mean_2, var_2, nt_test_perclass)
yt_test = np.concatenate([np.zeros(nt_test_perclass),
                          np.ones(nt_test - nt_test_perclass)])

# --- Constants and source classifiers used by the tests --------------------
RF_SIZE = 10
classes_test = [0, 1]
node_test = 5
node_test2 = 4
feats_test = np.array([0, 1])
values_test = np.array([5, 10])

clf_source_dt = DecisionTreeClassifier(max_depth=None)
clf_source_rf = RandomForestClassifier(n_estimators=RF_SIZE)
clf_source_dt.fit(Xs, ys)
clf_source_rf.fit(Xs, ys)
52+
53+
def test_depth():
    """Call the depth helpers of ``adapt._tree_utils`` on the source models.

    No value is asserted; the test passes when none of the calls raises.
    """
    ut.depth_tree(clf_source_dt)
    ut.depth_rf(clf_source_rf)
    ut.depth(clf_source_dt, node_test)
    # One depth per node: pass the index of every node of the fitted tree.
    all_nodes = np.arange(clf_source_dt.tree_.node_count)
    ut.depth_array(clf_source_dt, all_nodes)
58+
59+
def test_rules():
    """Call the rule-extraction and rule-comparison helpers of ``_tree_utils``.

    No value is asserted; the test passes when none of the calls raises.
    The calls are kept in this order because they share ``clf_source_dt``.
    """
    tree = clf_source_dt.tree_

    # Navigation helpers, on the raw tree and on the estimator.
    ut.sub_nodes(tree, node_test)
    ut.find_parent_vtree(tree, node_test)
    ut.find_parent(clf_source_dt, node_test)

    # Rule extraction for the two test nodes.
    ut.extract_rule_vtree(tree, node_test)
    p, t, b = ut.extract_rule(clf_source_dt, node_test)
    p2, t2, b2 = ut.extract_rule(clf_source_dt, node_test2)
    rule = (p, t, b)
    rule2 = (p2, t2, b2)
    split_0 = (p[0], t[0])

    # Rule predicates.
    ut.isinrule(rule, split_0)
    ut.isdisj_feat(p[0], t[0], p[1], t[1])
    ut.isdisj(rule, rule2)

    # The feature count is read from the training matrix rather than from
    # the estimator's ``n_features_`` attribute, which recent scikit-learn
    # releases no longer provide.
    ut.bounds_rule(rule, Xs.shape[1])

    ut.extract_leaves_rules(clf_source_dt)
    # NOTE(review): add_to_parents presumably modifies clf_source_dt in
    # place, which later tests also use - confirm this is intended.
    ut.add_to_parents(clf_source_dt, node_test, values_test)
80+
81+
def test_splits():
    """Call the split-coherence helpers of ``adapt._tree_utils``.

    No value is asserted; the test passes when none of the calls raises.
    """
    leaf_ids, _ = ut.extract_leaves_rules(clf_source_dt)

    feats, thresholds, sides = ut.extract_rule(clf_source_dt, node_test)
    feats2, thresholds2, sides2 = ut.extract_rule(clf_source_dt, node_test2)
    rule = (feats, thresholds, sides)
    rule2 = (feats2, thresholds2, sides2)

    ut.coherent_new_split(feats[1], thresholds[1], rule2)
    ut.liste_non_coherent_splits(clf_source_dt, rule)

    # One zero-filled (feature, threshold) record per internal node.
    split_dtype = [("phi", '<i8'), ("th", '<f8')]
    n_internal = clf_source_dt.tree_.node_count - leaf_ids.size
    all_splits = np.zeros(n_internal, dtype=split_dtype)

    coh_splits = ut.all_coherent_splits(rule, all_splits)
    s = coh_splits.size

    ut.filter_feature(all_splits, feats_test)
    # Uniform probabilities over the coherent splits.
    uniform = np.ones(s) / s
    ut.new_random_split(uniform, coh_splits)
97+
98+
def test_error():
    """Compute the node error and the leaf error at ``node_test``.

    Returns
    -------
    tuple
        ``(error, leaf_error)`` as returned by the two ``_tree_utils`` helpers.
    """
    node_err = ut.error(clf_source_dt, node_test)
    leaf_err = ut.leaf_error(clf_source_dt, node_test)
    return node_err, leaf_err
102+
103+
def test_distribution():
    """Call the class-distribution helpers of ``adapt._tree_utils``.

    No value is asserted; the test passes when none of the calls raises.
    """
    ut.get_children_distributions(clf_source_dt, node_test)
    ut.get_node_distribution(clf_source_dt, node_test)
    ut.compute_class_distribution(classes_test, ys)

    # Reuse the split stored at the root of the source tree (node 0).
    fitted_tree = clf_source_dt.tree_
    root_feature = fitted_tree.feature[0]
    root_threshold = fitted_tree.threshold[0]
    ut.compute_Q_children_target(Xs, ys, root_feature, root_threshold, classes_test)
111+
112+
def test_pruning_risk():
    """Call the two pruning-risk helpers on the source decision tree.

    No value is asserted; the test passes when neither call raises.
    """
    source_tree = clf_source_dt
    ut.compute_LLR_estimates_homog(source_tree)
    ut.contain_leaf_to_not_prune(source_tree)
115+
116+
def test_divergence_computation():
    """Call the entropy, gain and divergence helpers of ``adapt._tree_utils``.

    Source distributions are read at ``node_test``; target distributions come
    from applying the root split of the source tree to the target sample.
    No value is asserted; the test passes when none of the calls raises.
    """
    fitted_tree = clf_source_dt.tree_
    root_feature = fitted_tree.feature[0]
    root_threshold = fitted_tree.threshold[0]

    src_parent = ut.get_node_distribution(clf_source_dt, node_test)
    src_left, src_right = ut.get_children_distributions(clf_source_dt, node_test)
    tgt_left, tgt_right = ut.compute_Q_children_target(
        Xt, yt, root_feature, root_threshold, classes_test)

    # Impurity and gain measures.
    ut.H(src_parent)
    ut.GINI(src_parent)
    ut.IG(src_parent, [tgt_left, tgt_right])
    ut.DG(src_left, src_right, tgt_left, tgt_right)
    ut.JSD(tgt_left, src_left)

    ut.KL_divergence(src_left, tgt_left)
    ut.threshold_selection(src_parent, src_left, src_right,
                           Xt, yt, root_feature, classes_test)
133+
134+
135+

0 commit comments

Comments
 (0)